
Add bypass distillation (blockwise local KD) to puzzletron pipeline#1111

Open
Separius wants to merge 5 commits into feature/puzzletron from ssameni/puzzletron-bypass
Conversation


@Separius Separius commented Mar 24, 2026

Bypass distillation trains alternative transformer block configurations using per-block knowledge distillation from the teacher model, producing a library of better "puzzle pieces" for the MIP solver. It is most beneficial for aggressive FFN pruning (≤ 1/8 of teacher width) or significant KV head compression.

Changes:

  • Add modelopt/torch/puzzletron/bypass_distillation/ module with full training loop, stitched model factory, checkpoint management, and data classes
  • Integrate bypass as optional Step 3 in puzzletron.py and puzzletron_nas_plugin.py (pipeline progress counter updates to 9 steps when bypass is enabled)
  • Add HuggingFace auto-download and skip-if-exists logic to puzzletron_nas_plugin.py for all pipeline steps
  • Add normalized_mse_loss, vectorwise_normalized_mse_loss, and batched_normalized_mse_loss to sewing_kit/utils.py
  • Fix child_init.py: support a list of pruning mixins; treat a None override as "keep original value" instead of raising TypeCheckError
  • Fix dataset.py: graceful fallback when tokenizer has no chat_template (base models)
  • Add _copy_auto_map_code_files to checkpoint_utils_hf.py so modeling Python files are copied alongside config.json (required for trust_remote_code checkpoints such as NemotronH)
  • Add create_train_dataloader to dataloaders.py
  • Add MoEChannelPruning to MlpInitMode enum
  • Add default pruning_mixins() to ModelDescriptor base class
  • Add NemotronHKVHeadsLayerDescriptor and kv_heads mixin to NemotronH descriptor; fix _set_keys_to_learn to skip Mamba blocks during subblock_attention bypass (based on block config)
  • Enable bypass in llama-3_1-8B_pruneffn_memory config; add example bypass/defaults.yaml
  • Update README with bypass documentation: when to use, time cost, sequential execution, W&B logging
  • Add unit tests for loss functions and distribution utilities
  • Add GPU integration tests for bypass (FFN pruning, KV compression, multi-config sweep, checkpoint validation)
  • Fix test_puzzletron.py assertion to handle variable GPU counts
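To illustrate the idea behind blockwise local KD, here is a deliberately tiny, hypothetical sketch (a one-parameter "block" y = w * x; none of these names come from the PR): each candidate block is trained in isolation to reproduce the frozen teacher block's outputs, with no end-to-end objective.

```python
def blockwise_kd(teacher_w, student_w, xs, lr=0.05, steps=200):
    """Local KD on a toy one-parameter block y = w * x: fit the student
    weight to reproduce the frozen teacher block's outputs on the same
    inputs, with no end-to-end objective."""
    for _ in range(steps):
        # gradient of mean((student_w*x - teacher_w*x)**2) w.r.t. student_w
        grad = sum(2.0 * (student_w - teacher_w) * x * x for x in xs) / len(xs)
        student_w -= lr * grad
    return student_w
```

In the real pipeline each trained block becomes one "puzzle piece" in the library the MIP solver selects from.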

Summary by CodeRabbit

  • New Features

    • Optional "Bypass Distillation" stage: blockwise local distillation with end-to-end training, stitching, checkpointing, new loss utilities, dataloaders, and improved dynamic progress reporting.
  • Documentation

    • Added BYPASS.md and updated example README to note bypass is disabled by default and how to enable/configure it.
  • Tests

    • New integration and unit tests for bypass flows, losses, utils, checkpointing and resume behavior.

@Separius Separius requested review from a team as code owners March 24, 2026 16:21
@Separius Separius requested a review from cjluo-nv March 24, 2026 16:21

copy-pr-bot bot commented Mar 24, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


coderabbitai bot commented Mar 24, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: d0843d5f-59c7-4551-8a18-f119ef8a3ad3

📥 Commits

Reviewing files that changed from the base of the PR and between 346408b and 53f2a33.

📒 Files selected for processing (1)
  • modelopt/torch/puzzletron/bypass_distillation/training_loop.py
✅ Files skipped from review due to trivial changes (1)
  • modelopt/torch/puzzletron/bypass_distillation/training_loop.py

📝 Walkthrough

Walkthrough

Adds an optional bypass (blockwise local) distillation stage: new bypass package with stitched teacher–student factory, distributed training loop and checkpointing, model/pruning extensions, normalized-MSE losses, dataloader helper, example configs/docs, and unit/GPU tests; integrates bypass into Puzzletron control flow.

Changes

Cohort / File(s) Summary
Bypass package & core logic
modelopt/torch/puzzletron/bypass_distillation/__init__.py, modelopt/torch/puzzletron/bypass_distillation/training_loop.py, modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py, modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py, modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py, modelopt/torch/puzzletron/bypass_distillation/data_classes.py
New bypass distillation package: entrypoint, sweep/run orchestration, stitched teacher↔student factory, distributed per-block training loop, checkpoint save/load utilities, experiment id/dir helpers, and dataclasses.
Pipeline integration & orchestrator
modelopt/torch/puzzletron/puzzletron.py, modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py, examples/puzzletron/main.py
Inserted optional bypass stage into main pipeline/NAS plugin; added _total_steps, dynamic progress reporting, restartable *.complete markers, HF auto-download path, and refactored setup with longer distributed timeouts.
Stitched-model & pruning extensions
modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py, modelopt/torch/puzzletron/anymodel/models/nemotron_h/...nemotron_h_model_descriptor.py, modelopt/torch/puzzletron/pruning/pruning_utils.py
Added ModelDescriptor.pruning_mixins() hook; Nemotron‑H KV‑heads layer descriptor and KV‑heads pruning mixin; added MoEChannelPruning enum and dispatch into MLP init/pruning flow.
Loss utilities
modelopt/torch/puzzletron/sewing_kit/utils.py, modelopt/torch/puzzletron/tools/kd_model.py
Added vectorwise_normalized_mse_loss and batched_normalized_mse_loss; adjusted normalized_mse_loss epsilon placement.
Data & dataloaders
modelopt/torch/puzzletron/utils/data/dataloaders.py, modelopt/torch/puzzletron/utils/data/dataset.py
Added create_train_dataloader (infinite ConstantLengthDataset-backed loader) and safer chat-template fallback for tokenizers without chat_template.
Stitch/child-init & checkpoint tweaks
modelopt/torch/puzzletron/tools/bypassed_training/child_init.py, modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
Child init accepts list-of-mixins and avoids nulling overrides; checkpoint helper copies auto_map Python files into HF checkpoints.
Checkpoint helpers & utils
modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py, .../bypass_utils.py
Distributed-safe save/load of per-block stitched state, latest/best run discovery, and experiment-dir setup.
Examples, configs & docs
examples/puzzletron/BYPASS.md, examples/puzzletron/README.md, examples/puzzletron/configs/.../bypass/defaults.yaml, examples/puzzletron/configs/.../*.yaml
New BYPASS documentation, shared bypass defaults, many example Hydra configs updated to reference or document enabling bypass, and numerous small YAML composition files.
Tests — GPU & unit
tests/gpu/torch/puzzletron/test_bypass.py, tests/gpu/.../resources/.../bypass/test_bypass.yaml, tests/unit/torch/puzzletron/test_bypass_losses.py, tests/unit/torch/puzzletron/test_bypass_utils.py, tests/gpu/.../test_puzzletron.py
Added GPU integration tests for bypass workflows and checkpointing; unit tests for normalized-MSE losses and bypass utilities; adjusted distributed init timeouts and some test assertions.
Misc & small edits
various files (tests/*, modelopt/*, examples/*)
Import reorders, parsing/logging tweaks, minor test changes, and many YAML pointer files for config composition.

Sequence Diagram(s)

sequenceDiagram
    actor User
    participant Launcher as "launch_bypass_distillation(hydra_cfg)"
    participant Orchestrator as "run_bypassed_training(cfg)"
    participant Factory as "stitched_model_factory()"
    participant Data as "DataLoader / Teacher"
    participant Trainer as "train()"
    participant Checkpoint as "save_bypass_checkpoint()"

    User->>Launcher: provide Hydra cfg (single or sweep)
    Launcher->>Orchestrator: start run(s)
    Orchestrator->>Data: load teacher model & dataloaders
    Orchestrator->>Factory: build stitched teacher & student modules
    Factory-->>Orchestrator: return stitched modules + descriptors
    Orchestrator->>Trainer: start training loop
    loop per iteration
        Trainer->>Data: fetch batch
        Trainer->>Trainer: teacher forward -> capture activations
        Trainer->>Trainer: student forward -> compute per-block losses
        Trainer->>Trainer: backward, grad scale/clip, optimizer step
        Trainer->>Checkpoint: conditional save, write markers, symlink
        Checkpoint-->>Trainer: sync / resume info
    end
    Trainer-->>Orchestrator: training complete
    Orchestrator-->>Launcher: run finished

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): docstring coverage is 73.68%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.
✅ Passed checks (3 passed)
  • Description Check (✅ Passed): check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check (✅ Passed): the PR title clearly and concisely describes the main feature being added, bypass distillation (blockwise local KD) in the puzzletron pipeline, which is the primary objective across all file changes.
  • Security Anti-Patterns (✅ Passed): no torch.load() with weights_only=False, no numpy.load() with allow_pickle=True, no trust_remote_code=True, no eval/exec calls, no nosec comments.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

Comment @coderabbitai help to get the list of available commands and usage tips.


@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 8

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tests/gpu/torch/puzzletron/test_puzzletron.py (1)

236-245: ⚠️ Potential issue | 🟡 Minor

The fallback printer still emits only rank-local values.

This branch now advertises num_layers={total_layers}, but it still prints only the contents of rank_{rank}.pth and is executed on rank 0 only. On multi-GPU runs the suggested EXPECTED_PRUNING_VALUES snippet will be incomplete.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu/torch/puzzletron/test_puzzletron.py` around lines 236 - 245, The
printer currently outputs only rank-local pruning_scores causing incomplete
EXPECTED_PRUNING_VALUES for multi-GPU runs; modify the logic so rank 0
aggregates pruning data from all ranks before printing: collect and merge
per-rank pruning_scores (or load all rank_{rank}.pth files) into a global
pruning_scores for each layer_name, compute the global score and channels (e.g.,
combine/average or gather channel indices across ranks) respecting total_layers,
and then have rank 0 iterate over layer_names using the aggregated values when
printing the block that uses total_layers and prints the EXPECTED_PRUNING_VALUES
snippet.
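As a hedged sketch of the aggregation step the prompt describes (the helper name and the averaging strategy are hypothetical, not from the PR), rank 0 could merge per-rank score dicts before printing:

```python
def aggregate_pruning_scores(per_rank_scores):
    """Merge per-rank {layer_name: score} dicts (e.g. loaded from each
    rank_{rank}.pth file) into one global view by averaging the scores of
    layers reported by multiple ranks."""
    totals, counts = {}, {}
    for scores in per_rank_scores:
        for layer, score in scores.items():
            totals[layer] = totals.get(layer, 0.0) + score
            counts[layer] = counts.get(layer, 0) + 1
    return {layer: totals[layer] / counts[layer] for layer in totals}
```

Whether averaging or gathering channel indices is the right merge depends on how the test defines a layer's global score.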
modelopt/torch/puzzletron/pruning/pruning_utils.py (1)

40-47: ⚠️ Potential issue | 🟠 Major

MoEChannelPruning is exposed before the init path supports it.

modelopt/torch/puzzletron/tools/bypassed_training/child_init.py now branches on this enum and forwards it into _init_mlp_module(), but _init_mlp_module() still falls through to Unsupported mlp_init_mode for this value when expert widths change. Any config that selects MoEChannelPruning will fail during child initialization.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/pruning/pruning_utils.py` around lines 40 - 47, The
enum MlpInitMode now includes MoEChannelPruning but _init_mlp_module still
treats that case as unsupported; update the _init_mlp_module implementation to
handle MlpInitMode.MoEChannelPruning (the same call-site that child_init.py
forwards into) by adding a branch for MlpInitMode.MoEChannelPruning that
performs the correct initialization when expert widths change (e.g., adapt the
weight/activation shapes by slicing/reshaping or reuse the
ConcatExpertsIntoDenseFFN logic where appropriate), so the child init no longer
falls through to the "Unsupported mlp_init_mode" error for MoEChannelPruning.
🧹 Nitpick comments (5)
modelopt/torch/puzzletron/tools/bypassed_training/child_init.py (1)

804-806: This change makes explicit null resets impossible.

Treating None as “keep original” fixes the accidental overwrite, but it also removes the only way for JSON/YAML overrides to clear an optional field back to None. If callers need both behaviors, use a sentinel for “no override” and reserve None for explicit clearing.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/tools/bypassed_training/child_init.py` around lines
804 - 806, The current override function (override) treats item_overrides ==
None as "keep original", which prevents callers from explicitly clearing a value
to None via JSON/YAML; change the logic to use a distinct sentinel (e.g., a new
unique object like NO_OVERRIDE) to represent "no override" and reserve None in
item_overrides to mean "set to None"/clear the field, updating the override
function to check against the sentinel (NO_OVERRIDE) instead of None and adjust
any callers that construct overrides to use the sentinel when they mean "leave
original".
modelopt/torch/puzzletron/utils/data/dataset.py (1)

123-130: Keep role markers in the no-template fallback.

Joining only content collapses system/user/assistant turns into plain text, which changes the supervision for chat datasets. A lightweight fallback like "{role}: {content}" preserves the conversation structure without relying on a tokenizer template.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/utils/data/dataset.py` around lines 123 - 130, The
fallback that builds sample when getattr(self.tokenizer, "chat_template", None)
is None should preserve role markers instead of joining only message["content"];
update the else branch in dataset.py (the block that currently sets sample =
"\n".join(m["content"] for m in sample)) to join messages using a lightweight
role-prefixed format like "{role}: {content}" so conversation turns
(system/user/assistant) are retained; keep using the same sample variable and
ensure this mirrors the structure expected by downstream code that consumes
apply_chat_template outputs.
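A minimal sketch of the role-preserving fallback (the function name is hypothetical; `apply_chat_template` is the standard HuggingFace tokenizer method):

```python
def render_messages(messages, tokenizer=None):
    """Render a chat sample: use the tokenizer's chat template when one
    exists, otherwise fall back to a role-prefixed join so the
    system/user/assistant structure is preserved."""
    if tokenizer is not None and getattr(tokenizer, "chat_template", None):
        return tokenizer.apply_chat_template(messages, tokenize=False)
    return "\n".join(f"{m['role']}: {m['content']}" for m in messages)
```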
modelopt/torch/puzzletron/utils/parsing.py (1)

337-345: Don’t silently treat every NaN as a no-op block.

This formatter now drops any NaN entry and can report No trainable blocks found. If a trainable block diverges, the failure disappears from the logs instead of surfacing. Filter only known skipped block types, or emit a separate warning for unexpected NaNs.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/utils/parsing.py` around lines 337 - 345, The
current filtering silently drops any NaN in losses_dict (and prunes
best_steps_dict/best_values_dict to match), which hides diverging trainable
blocks; instead, update the logic around losses_dict, best_steps_dict and
best_values_dict so you only drop entries whose keys match known skipped block
types (e.g., the explicit list of no-op block names like "Mamba"), and for any
other NaN values emit a warning/error (via the existing logger) that a trainable
block produced NaN rather than removing it; ensure best_steps_dict and
best_values_dict are only pruned to match the filtered losses_dict after this
selective filtering and warning behavior.
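One possible shape for the selective filtering (the marker list and function name are assumptions for illustration; the real skipped-block list lives in the PR's config):

```python
import math
import warnings

SKIPPED_BLOCK_MARKERS = ("mamba",)  # block types expected to yield NaN here

def filter_block_losses(losses):
    """Drop NaN losses only for known skipped block types; any other NaN is
    kept and warned about so a diverging trainable block stays visible."""
    kept = {}
    for name, loss in losses.items():
        if math.isnan(loss):
            if any(marker in name.lower() for marker in SKIPPED_BLOCK_MARKERS):
                continue  # expected no-op block, safe to hide
            warnings.warn(f"trainable block {name!r} produced a NaN loss")
        kept[name] = loss
    return kept
```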
examples/puzzletron/main.py (1)

154-167: Progress messages in run_mip_only are hardcoded and inconsistent with the dynamic approach.

The run_full_puzzletron function now uses dynamic step counting (N = _total_steps(hydra_cfg)), but run_mip_only still uses hardcoded "7/8" and "8/8" progress messages. If bypass is configured, the step numbers would be incorrect (should be 8/9 and 9/9).

Consider applying the same dynamic step count logic here for consistency.

♻️ Suggested fix
 def run_mip_only(hydra_config_path: str):
     ...
     # Load hydra config
     hydra_cfg = initialize_hydra_config_for_dir(
         config_dir=hydra_config_dir,
         config_name=hydra_config_name,
         overrides=[],
     )
+    N = _total_steps(hydra_cfg)

     # Check if sweep mode is enabled
     if hasattr(hydra_cfg.mip, "sweep") and hydra_cfg.mip.sweep.get("enabled", False):
         mprint(
-            "Puzzletron Progress 7/8: running MIP sweep for multiple compression rates (multi-gpu)"
+            f"Puzzletron Progress {N-1}/{N}: running MIP sweep for multiple compression rates (multi-gpu)"
         )
         sweep.run_mip_sweep(hydra_cfg)
     else:
         # mip_and_realize_models (distributed processing)
         # TODO: How to make it part of mnt.search() api, similarly to run_full_puzzletron() API
-        mprint("Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu)")
+        mprint(f"Puzzletron Progress {N-1}/{N}: running MIP and realizing models (multi-gpu)")
         mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg)

     dist.cleanup()
-    mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)")
+    mprint(f"Puzzletron Progress {N}/{N}: puzzletron pipeline completed (multi-gpu)")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/puzzletron/main.py` around lines 154 - 167, Update run_mip_only to
compute the total steps like run_full_puzzletron by calling
_total_steps(hydra_cfg) and use that N when formatting the progress messages
instead of hardcoded "7/8" and "8/8"; specifically, replace the two mprint calls
around the conditional that currently show "Puzzletron Progress 7/8" and "8/8"
with dynamic messages using N (e.g., f"Puzzletron Progress {current_step}/{N}:
...") and ensure current_step increments are correct for both the sweep branch
(sweep.run_mip_sweep) and the mip branch
(mip_and_realize_models.launch_mip_and_realize_model) so progress displays
consistently with _total_steps.
modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py (1)

548-556: Unused variable num_trainable_params.

The variable num_trainable_params is computed but never used in this function or elsewhere. This appears to be residual code. Consider removing it to reduce unnecessary computation and improve code clarity.

♻️ Proposed removal
             assert "learning_rate" in cfg.training
-            num_trainable_params = sum(
-                p.requires_grad and submodule_name in p_name
-                for p_name, p in student_stitched_module.named_parameters()
-                if "dummy_param" not in p_name  # exclude placeholder params
-            )
-            # Do NOT enable dummy params: blocks with no real trainable parameters
-            # (e.g. Mamba blocks during an attention-only bypass run) should produce
-            # NaN loss so they are excluded from statistics — identical to the
-            # optimizer=None path in the training loop.

             student_module_parameters = {
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py`
around lines 548 - 556, Remove the unused computation of num_trainable_params:
delete the sum(...) assignment that iterates over
student_stitched_module.named_parameters() checking p.requires_grad and
submodule_name in p_name (and the "dummy_param" exclusion). Keep the surrounding
explanatory comment about dummy params if still relevant, but eliminate the dead
variable and its associated needless iteration to avoid wasted computation and
clarify stitched_model_factory.py.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py`:
- Around line 45-58: The fallback currently only sorts checkpoint directories by
iteration (get_iter_num) so when multiple checkpoints exist for the same iter we
may pick an older step; update the selection to parse both iteration and step
and sort by (iter, step) descending. Modify get_iter_num (or replace with
get_iter_and_step) to extract iter via r"iter-(\d+)" and step via r"step-(\d+)"
(defaulting to 0 when not present), return a tuple of two ints, and use that
tuple as the sort key for checkpoint_dirs before iterating to find the first
directory with "saving_completed" and returning str(latest_dir).
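A hedged sketch of the suggested (iter, step) sort key (helper names and the exact directory naming scheme are assumptions based on the regexes in the prompt):

```python
import re

def get_iter_and_step(name):
    """Sort key for checkpoint dirs such as "iter-3-step-1500": return
    (iteration, step), defaulting each to 0 when absent."""
    it = re.search(r"iter-(\d+)", name)
    st = re.search(r"step-(\d+)", name)
    return (int(it.group(1)) if it else 0, int(st.group(1)) if st else 0)

def latest_checkpoint(dirs):
    """Pick the newest checkpoint by (iter, step), not iteration alone."""
    return max(dirs, key=get_iter_and_step, default=None)
```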

In `@modelopt/torch/puzzletron/bypass_distillation/training_loop.py`:
- Around line 673-677: Replace the hardcoded trust_remote_code=True in the
AutoTokenizer.from_pretrained call with the same caller-configurable
trust_remote_code flag you already read from the descriptor earlier (the
variable used for model config loading at lines ~597/631); specifically update
the tokenizer = AutoTokenizer.from_pretrained(...) invocation that uses
cfg.teacher_dir so it passes the descriptor-derived trust_remote_code value
instead of True, ensuring the flag remains configurable and defaults to False.

In `@modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py`:
- Around line 146-149: The pre-checks that treat presence of files like
(teacher_dir / "config.json"), any rank_*.pth, files under
pruned_ckpts_output_dir, or library outputs as sufficient to skip stages are
unsafe; change these guards to rely on durable completion markers (e.g., a .done
or .complete file) created at the successful end of
conversion/scoring/pruning/library build instead of existence-only checks, so
functions like the conversion branch around teacher_dir/config.json, the rank_*
checkpoint checks, and the pruned_ckpts_output_dir/library checks only skip when
their corresponding completion marker exists; ensure launch_score_activations()
remains the stricter gate for pruning-activation scoring but remove or weaken
the naive existence checks noted at the conversion lines (the block using
teacher_dir/config.json) and the other mentioned blocks (191-193, 286-289) to
check for the specific "<stage>.complete" marker before skipping.

In `@modelopt/torch/puzzletron/sewing_kit/utils.py`:
- Around line 452-454: The normalization denominator is computed as
F.mse_loss(target, torch.zeros_like(target) + epsilon, ...) which shifts the
target by epsilon and biases the scale; instead compute the denominator as
F.mse_loss(target, torch.zeros_like(target), reduction=reduction) + epsilon (or
clamp_min the denominator to epsilon) so you add epsilon to the final scalar
denominator instead of to the zero tensor; update the occurrences around the
loss assignment (loss, input, target, epsilon, F.mse_loss) and the similar block
at lines 479-482 accordingly.
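The epsilon-placement point can be shown with a small pure-Python sketch (a simplification of the tensor version in sewing_kit/utils.py):

```python
def normalized_mse(pred, target, epsilon=1e-8):
    """Normalized MSE with epsilon added to the final scalar denominator.
    Adding epsilon to the target tensor instead would shift every element
    and bias the normalization scale."""
    n = len(target)
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / n
    target_power = sum(t * t for t in target) / n
    return mse / (target_power + epsilon)
```

With epsilon outside the target, the loss stays exactly scale-invariant and the zero-target case degrades gracefully instead of being skewed.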

In `@modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py`:
- Around line 380-396: The auto_map parsing in checkpoint_utils_hf.py
incorrectly assumes each model_config.auto_map value is a dotted string; update
the logic that builds module_files (and any usage of class_ref) to first
normalize each value by: if it's a list/tuple take the first element, if it
contains a repo qualifier split off the "repo_id--" prefix, then take the module
part before the first '.' and append ".py" (so "tokenization_my.py"); apply this
normalization where module_files is created and when iterating filenames so
lists/tuples and repo-qualified references are handled and the correct source
filenames are copied.
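A sketch of the normalization the comment asks for (the function name is hypothetical; the "repo_id--module.Class" form is the documented HuggingFace auto_map convention for remote code):

```python
def auto_map_module_file(value):
    """Normalize one auto_map entry to its source filename: handle
    list/tuple values, strip a "repo_id--" qualifier, then keep the module
    part before the class name and append ".py"."""
    if isinstance(value, (list, tuple)):
        value = next(v for v in value if v is not None)
    if "--" in value:
        value = value.split("--", 1)[1]
    return value.split(".", 1)[0] + ".py"
```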

In `@modelopt/torch/puzzletron/utils/data/dataloaders.py`:
- Around line 89-90: The DataLoader factory allows num_workers>0 while
ConstantLengthDataset.__iter__ does not shard via get_worker_info(), causing
duplicate samples; update the dataloaders (the function returning DataLoader in
modelopt/torch/puzzletron/utils/data/dataloaders.py) to either reject or
override num_workers>0 (e.g., force num_workers=0 or raise an error) until
ConstantLengthDataset.__iter__ is updated to use
torch.utils.data.get_worker_info() for worker sharding, and apply the same guard
to the other DataLoader creation call sites noted around the 114-118 region;
reference ConstantLengthDataset.__iter__ when adding the guard so future
implementers know to remove it after adding proper worker-aware iteration.
- Around line 98-99: The call to train_data.shuffle(seed=shuffle_seed,
keep_in_memory=True) fails for streaming (Iterable) datasets because
IterableDataset.shuffle() doesn't accept keep_in_memory; update the code that
checks shuffle_seed to detect streaming datasets (e.g., via whatever marker
load_streaming_fn sets or by checking hasattr(train_data, "__iter__") vs
__len__/isinstance of IterableDataset) and branch: for non-streaming datasets
call train_data.shuffle(seed=shuffle_seed, keep_in_memory=True) as before, and
for streaming/iterable datasets call train_data.shuffle(seed=shuffle_seed)
without keep_in_memory; ensure you modify the block that references shuffle_seed
and train_data.shuffle so runtime errors are avoided when load_streaming_fn()
returns a streaming dataset.
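The streaming/map-style branch can be sketched as below (checking `__len__` is one simple way to detect a map-style dataset; an `isinstance(ds, datasets.IterableDataset)` check is the stricter alternative):

```python
def shuffle_dataset(ds, seed):
    """Shuffle a dataset, branching on type: map-style datasets accept
    keep_in_memory, while streaming IterableDatasets do not."""
    if hasattr(ds, "__len__"):  # map-style (indexable) dataset
        return ds.shuffle(seed=seed, keep_in_memory=True)
    return ds.shuffle(seed=seed)  # streaming / iterable dataset
```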

In `@tests/gpu/torch/puzzletron/test_bypass.py`:
- Line 213: The timeout passed to dist.setup uses timedelta(10) which means 10
days; change it to an explicit unit like timedelta(seconds=10) (or
timedelta(minutes=10) if intended) to avoid 10-day test hangs — locate the call
to dist.setup (symbol: dist.setup) in tests/gpu/torch/puzzletron/test_bypass.py
and the other listed files and replace timedelta(10) with timedelta(seconds=10)
(or the correct unit) in each occurrence.
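The timedelta pitfall is easy to demonstrate (`init_timeout` is an illustrative name, not a function from the PR):

```python
from datetime import timedelta

def init_timeout(seconds=10):
    """timedelta's first positional parameter is days, so timedelta(10)
    silently requests a ten-day timeout; always name the unit."""
    return timedelta(seconds=seconds)
```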

---

Outside diff comments:
In `@modelopt/torch/puzzletron/pruning/pruning_utils.py`:
- Around line 40-47: The enum MlpInitMode now includes MoEChannelPruning but
_init_mlp_module still treats that case as unsupported; update the
_init_mlp_module implementation to handle MlpInitMode.MoEChannelPruning (the
same call-site that child_init.py forwards into) by adding a branch for
MlpInitMode.MoEChannelPruning that performs the correct initialization when
expert widths change (e.g., adapt the weight/activation shapes by
slicing/reshaping or reuse the ConcatExpertsIntoDenseFFN logic where
appropriate), so the child init no longer falls through to the "Unsupported
mlp_init_mode" error for MoEChannelPruning.

In `@tests/gpu/torch/puzzletron/test_puzzletron.py`:
- Around line 236-245: The printer currently outputs only rank-local
pruning_scores causing incomplete EXPECTED_PRUNING_VALUES for multi-GPU runs;
modify the logic so rank 0 aggregates pruning data from all ranks before
printing: collect and merge per-rank pruning_scores (or load all rank_{rank}.pth
files) into a global pruning_scores for each layer_name, compute the global
score and channels (e.g., combine/average or gather channel indices across
ranks) respecting total_layers, and then have rank 0 iterate over layer_names
using the aggregated values when printing the block that uses total_layers and
prints the EXPECTED_PRUNING_VALUES snippet.

---

Nitpick comments:
In `@examples/puzzletron/main.py`:
- Around line 154-167: Update run_mip_only to compute the total steps like
run_full_puzzletron by calling _total_steps(hydra_cfg) and use that N when
formatting the progress messages instead of hardcoded "7/8" and "8/8";
specifically, replace the two mprint calls around the conditional that currently
show "Puzzletron Progress 7/8" and "8/8" with dynamic messages using N (e.g.,
f"Puzzletron Progress {current_step}/{N}: ...") and ensure current_step
increments are correct for both the sweep branch (sweep.run_mip_sweep) and the
mip branch (mip_and_realize_models.launch_mip_and_realize_model) so progress
displays consistently with _total_steps.

In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py`:
- Around line 548-556: Remove the unused computation of num_trainable_params:
delete the sum(...) assignment that iterates over
student_stitched_module.named_parameters() checking p.requires_grad and
submodule_name in p_name (and the "dummy_param" exclusion). Keep the surrounding
explanatory comment about dummy params if still relevant, but eliminate the dead
variable and its associated needless iteration to avoid wasted computation and
clarify stitched_model_factory.py.

In `@modelopt/torch/puzzletron/tools/bypassed_training/child_init.py`:
- Around line 804-806: The current override function (override) treats
item_overrides == None as "keep original", which prevents callers from
explicitly clearing a value to None via JSON/YAML; change the logic to use a
distinct sentinel (e.g., a new unique object like NO_OVERRIDE) to represent "no
override" and reserve None in item_overrides to mean "set to None"/clear the
field, updating the override function to check against the sentinel
(NO_OVERRIDE) instead of None and adjust any callers that construct overrides to
use the sentinel when they mean "leave original".
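A minimal sketch of the sentinel idea, using hypothetical names NO_OVERRIDE and apply_override rather than the actual child_init.py API:

```python
# NO_OVERRIDE is a unique sentinel meaning "leave the original value alone",
# which frees up None to mean "explicitly clear this field".
NO_OVERRIDE = object()

def apply_override(original, override=NO_OVERRIDE):
    if override is NO_OVERRIDE:
        return original  # caller did not mention this field at all
    return override      # includes an explicit None, which now clears the value

assert apply_override(42) == 42          # no override given: keep original
assert apply_override(42, None) is None  # explicit null: clear the field
assert apply_override(42, 7) == 7        # normal override
```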

In `@modelopt/torch/puzzletron/utils/data/dataset.py`:
- Around line 123-130: The fallback that builds sample when
getattr(self.tokenizer, "chat_template", None) is None should preserve role
markers instead of joining only message["content"]; update the else branch in
dataset.py (the block that currently sets sample = "\n".join(m["content"] for m
in sample)) to join messages using a lightweight role-prefixed format like
"{role}: {content}" so conversation turns (system/user/assistant) are retained;
keep using the same sample variable and ensure this mirrors the structure
expected by downstream code that consumes apply_chat_template outputs.
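The role-prefixed fallback amounts to a one-line change; render_without_chat_template below is a hypothetical stand-in for the else branch in dataset.py:

```python
def render_without_chat_template(messages):
    # Keep the role markers so conversation turns survive, instead of
    # joining only message["content"] as the current fallback does.
    return "\n".join(f"{m['role']}: {m['content']}" for m in messages)

sample = render_without_chat_template([
    {"role": "system", "content": "Be concise."},
    {"role": "user", "content": "Summarize the PR."},
])
# "system: Be concise.\nuser: Summarize the PR."
```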

In `@modelopt/torch/puzzletron/utils/parsing.py`:
- Around line 337-345: The current filtering silently drops any NaN in
losses_dict (and prunes best_steps_dict/best_values_dict to match), which hides
diverging trainable blocks; instead, update the logic around losses_dict,
best_steps_dict and best_values_dict so you only drop entries whose keys match
known skipped block types (e.g., the explicit list of no-op block names like
"Mamba"), and for any other NaN values emit a warning/error (via the existing
logger) that a trainable block produced NaN rather than removing it; ensure
best_steps_dict and best_values_dict are only pruned to match the filtered
losses_dict after this selective filtering and warning behavior.
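A sketch of the selective filtering described above, assuming plain dicts and the stdlib logger; the real parsing.py names and logger may differ:

```python
import logging
import math

# Block types that legitimately report no loss and may be dropped silently.
SKIPPED_BLOCK_TYPES = ("Mamba",)

def filter_losses(losses_dict, logger=logging.getLogger(__name__)):
    kept = {}
    for name, loss in losses_dict.items():
        if math.isnan(loss):
            if any(skip in name for skip in SKIPPED_BLOCK_TYPES):
                continue  # expected no-op block: drop the entry
            # A trainable block diverged: warn loudly but keep it visible.
            logger.warning("Trainable block %s produced NaN loss", name)
        kept[name] = loss
    return kept

losses = {"Mamba.layer_3": float("nan"), "attention.layer_5": 0.42}
filtered = filter_losses(losses)  # Mamba entry dropped, attention kept
```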

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 371acd83-77b9-4396-8a82-eddd5b11dd40

📥 Commits

Reviewing files that changed from the base of the PR and between e508b76 and e018ca0.

📒 Files selected for processing (27)
  • examples/puzzletron/README.md
  • examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml
  • examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml
  • examples/puzzletron/main.py
  • modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py
  • modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py
  • modelopt/torch/puzzletron/bypass_distillation/__init__.py
  • modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py
  • modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py
  • modelopt/torch/puzzletron/bypass_distillation/data_classes.py
  • modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py
  • modelopt/torch/puzzletron/bypass_distillation/training_loop.py
  • modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py
  • modelopt/torch/puzzletron/pruning/pruning_utils.py
  • modelopt/torch/puzzletron/puzzletron.py
  • modelopt/torch/puzzletron/sewing_kit/utils.py
  • modelopt/torch/puzzletron/tools/bypassed_training/child_init.py
  • modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
  • modelopt/torch/puzzletron/utils/data/dataloaders.py
  • modelopt/torch/puzzletron/utils/data/dataset.py
  • modelopt/torch/puzzletron/utils/parsing.py
  • tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/bypass/test_bypass.yaml
  • tests/gpu/torch/puzzletron/test_bypass.py
  • tests/gpu/torch/puzzletron/test_puzzletron.py
  • tests/unit/torch/puzzletron/__init__.py
  • tests/unit/torch/puzzletron/test_bypass_losses.py
  • tests/unit/torch/puzzletron/test_bypass_utils.py

Comment on lines +45 to +58
# If "latest" doesn't exist, look explicitly into directories with `*iter-*`
candidate_dirs = [d for d in run_parent_dir.glob("*iter-*") if d.is_dir()]

if not candidate_dirs:
return None

def get_iter_num(dir_name):
match = re.search(r"iter-(\d+)", dir_name.name)
return int(match.group(1)) if match else 0

checkpoint_dirs = sorted(candidate_dirs, key=get_iter_num, reverse=True)
for latest_dir in checkpoint_dirs:
if (latest_dir / "saving_completed").exists():
return str(latest_dir)

⚠️ Potential issue | 🟠 Major

Include step_num when picking the latest checkpoint.

This fallback only sorts on iter-(\d+). If a run writes multiple checkpoints inside the same iteration, resume can load an older step even though a newer checkpoint exists in the same run_parent_dir.

💡 Suggested fix
```diff
-    def get_iter_num(dir_name):
-        match = re.search(r"iter-(\d+)", dir_name.name)
-        return int(match.group(1)) if match else 0
-
-    checkpoint_dirs = sorted(candidate_dirs, key=get_iter_num, reverse=True)
+    def checkpoint_order(path: Path) -> tuple[int, int, float]:
+        match = re.search(r"iter-(\d+)(?:.*step-(\d+))?", path.name)
+        if not match:
+            return (0, 0, path.stat().st_mtime)
+        return (int(match.group(1)), int(match.group(2) or 0), path.stat().st_mtime)
+
+    checkpoint_dirs = sorted(candidate_dirs, key=checkpoint_order, reverse=True)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py`
around lines 45 - 58, The fallback currently only sorts checkpoint directories
by iteration (get_iter_num) so when multiple checkpoints exist for the same iter
we may pick an older step; update the selection to parse both iteration and step
and sort by (iter, step) descending. Modify get_iter_num (or replace with
get_iter_and_step) to extract iter via r"iter-(\d+)" and step via r"step-(\d+)"
(defaulting to 0 when not present), return a tuple of two ints, and use that
tuple as the sort key for checkpoint_dirs before iterating to find the first
directory with "saving_completed" and returning str(latest_dir).

Comment on lines +89 to +90
```python
    num_workers: int = 0,
) -> DataLoader:
```

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

```bash
#!/bin/bash
set -euo pipefail

rg -n -C3 'class ConstantLengthDataset|def __iter__|get_worker_info|worker_info' \
  modelopt/torch/puzzletron/utils/data/dataset.py \
  modelopt/torch/puzzletron/utils/data/dataloaders.py
```

Repository: NVIDIA/Model-Optimizer

Length of output: 1247


🏁 Script executed:

```bash
#!/bin/bash
set -euo pipefail

# Get the full __iter__ method implementation
echo "=== ConstantLengthDataset.__iter__() implementation ==="
sed -n '100,150p' modelopt/torch/puzzletron/utils/data/dataset.py

echo ""
echo "=== create_train_dataloader function signature and first 30 lines ==="
sed -n '80,120p' modelopt/torch/puzzletron/utils/data/dataloaders.py

echo ""
echo "=== Check for get_worker_info usage in dataset.py ==="
grep -n "get_worker_info" modelopt/torch/puzzletron/utils/data/dataset.py || echo "No get_worker_info found"

echo ""
echo "=== Check for any guards on num_workers in dataloaders.py ==="
grep -n "num_workers" modelopt/torch/puzzletron/utils/data/dataloaders.py
```

Repository: NVIDIA/Model-Optimizer

Length of output: 4020


num_workers is unsafe until ConstantLengthDataset is worker-sharded.

ConstantLengthDataset.__iter__() in modelopt/torch/puzzletron/utils/data/dataset.py walks self.dataset directly and never uses get_worker_info(). Passing num_workers > 0 to the DataLoader will cause each worker to read the same iterable stream, duplicating training samples silently.

Suggested guard
```diff
 def create_train_dataloader(
     bos_rate: float = 1.0,
     num_workers: int = 0,
 ) -> DataLoader:
     """Create an infinite training DataLoader over ConstantLengthDataset."""
+    if num_workers != 0:
+        raise ValueError(
+            "create_train_dataloader currently requires num_workers=0 because "
+            "ConstantLengthDataset is not sharded across DataLoader workers."
+        )
+
     if isinstance(dataset_path, str):
         dataset = load_dataset_fn(dataset_path, content_field, keep_in_memory)
```

Also applies to: 114-118

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/utils/data/dataloaders.py` around lines 89 - 90,
The DataLoader factory allows num_workers>0 while ConstantLengthDataset.__iter__
does not shard via get_worker_info(), causing duplicate samples; update the
dataloaders (the function returning DataLoader in
modelopt/torch/puzzletron/utils/data/dataloaders.py) to either reject or
override num_workers>0 (e.g., force num_workers=0 or raise an error) until
ConstantLengthDataset.__iter__ is updated to use
torch.utils.data.get_worker_info() for worker sharding, and apply the same guard
to the other DataLoader creation call sites noted around the 114-118 region;
reference ConstantLengthDataset.__iter__ when adding the guard so future
implementers know to remove it after adding proper worker-aware iteration.
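The sharding arithmetic the guard is waiting on is small. Below is a pure-Python sketch of it; in the real fix, `worker_id` and `num_workers` would come from `torch.utils.data.get_worker_info()` inside `ConstantLengthDataset.__iter__`, and `worker_shard` is a hypothetical helper, not part of the codebase:

```python
import itertools

def worker_shard(stream, worker_id, num_workers):
    """Yield every num_workers-th item from the stream, offset by worker_id."""
    return itertools.islice(stream, worker_id, None, num_workers)

items = list(range(10))
shard0 = list(worker_shard(iter(items), 0, 2))  # [0, 2, 4, 6, 8]
shard1 = list(worker_shard(iter(items), 1, 2))  # [1, 3, 5, 7, 9]
# Together the shards cover the stream exactly once, so num_workers > 0
# would no longer silently duplicate training samples.
```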

Comment on lines +98 to +99
```python
    if shuffle_seed is not None:
        train_data = train_data.shuffle(seed=shuffle_seed, keep_in_memory=True)
```

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

Does Hugging Face `datasets.IterableDataset.shuffle()` accept a `keep_in_memory` keyword argument?

💡 Result:

No. datasets.IterableDataset.shuffle() only takes seed, generator, and buffer_size (no keep_in_memory kwarg), per both the API reference and the implementation. [1], [2]


Branch the shuffle call for streaming datasets.

IterableDataset.shuffle() does not accept keep_in_memory=True. When load_streaming_fn() is used (streaming datasets), line 99 will fail at runtime with an unexpected keyword argument error.

Suggested fix
```diff
     train_data = dataset[dataset_name]
     if shuffle_seed is not None:
-        train_data = train_data.shuffle(seed=shuffle_seed, keep_in_memory=True)
+        if isinstance(train_data, datasets.IterableDataset):
+            train_data = train_data.shuffle(seed=shuffle_seed)
+        else:
+            train_data = train_data.shuffle(seed=shuffle_seed, keep_in_memory=True)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/utils/data/dataloaders.py` around lines 98 - 99,
The call to train_data.shuffle(seed=shuffle_seed, keep_in_memory=True) fails for
streaming (Iterable) datasets because IterableDataset.shuffle() doesn't accept
keep_in_memory; update the code that checks shuffle_seed to detect streaming
datasets (e.g., via whatever marker load_streaming_fn sets or by checking
hasattr(train_data, "__iter__") vs __len__/isinstance of IterableDataset) and
branch: for non-streaming datasets call train_data.shuffle(seed=shuffle_seed,
keep_in_memory=True) as before, and for streaming/iterable datasets call
train_data.shuffle(seed=shuffle_seed) without keep_in_memory; ensure you modify
the block that references shuffle_seed and train_data.shuffle so runtime errors
are avoided when load_streaming_fn() returns a streaming dataset.

@cjluo-nv cjluo-nv left a comment

Summary: Adds bypass distillation (blockwise local knowledge distillation) as an optional pipeline stage to puzzletron. Includes a full training loop, stitched model factory, checkpoint management, loss functions, data loader, configuration, and comprehensive tests. Also fixes bugs in child_init.py, dataset.py, and adds HF auto-download logic.

Issues Found:

  1. [Duplicated Code] normalized_mse_loss in sewing_kit/utils.py (diff lines 432-445) is an exact duplicate of the existing implementation in modelopt/torch/puzzletron/tools/kd_model.py:32-41. The new code should import and reuse the existing function rather than redefining it. The vectorwise_normalized_mse_loss and batched_normalized_mse_loss variants are new and fine, but they should build on the existing import.

  2. [Correctness / Security] training_loop.py:675 — AutoTokenizer.from_pretrained uses a hardcoded trust_remote_code=True. The variable trust_remote_code is already computed from the descriptor at line 648; this call should use trust_remote_code=trust_remote_code instead. (Flagged by pre-merge checks as well.)

  3. [Correctness / Security] bypass_checkpoint_utils.py:85,99 — torch.load() calls lack weights_only=True. The codebase convention (e.g., checkpoint_utils.py:43,77) is to use weights_only=True for state dict loading. These calls load state dicts and optimizer states respectively, which are pure tensor data and should use weights_only=True.

  4. [Correctness] training_loop.py — The except Exception as e block at the end of run_bypassed_training (around line 870) catches all exceptions and calls sys.exit(1) for non-SystemExit exceptions. This swallows the actual exception type and prevents proper test framework error reporting. In GPU tests, a failing bypass run will produce SystemExit(1) instead of the real traceback. Consider re-raising or at least logging before exit.

  5. [Correctness] stitched_model_factory.py:370-373 — The lambda closures in the stitched module creation loop (adapter=lambda v: InputArgs(target=v) and adapter=lambda v: InputArgs(input=v)) capture v correctly since they're arguments, but the loss target/input naming ("target" and "input") relies on block_loss_func accepting exactly these keyword arguments. If someone changes block_loss_func to e.g. batched_normalized_mse_loss, the keyword args don't match (batched_normalized_mse_loss takes input and target positional args, not kwargs via InputArgs). This coupling is implicit and fragile — consider documenting the contract or adding a **kwargs adapter.

  6. [Correctness] bypass_checkpoint_utils.py:89 — loaded_state_dict = {**stitched_module.state_dict(), **loaded_state_dict} merges the current state dict with the loaded one (loaded takes precedence). However, the current state dict is fetched before loading — if the model is on a different device, keys may contain tensors on the wrong device. The subsequent load_state_dict should handle this, but the intermediate merged dict is wasteful. Consider just using strict=False with load_state_dict directly.

  7. [Readability] stitched_model_factory.py — The bypass_factory_fn function is ~250 lines long with deeply nested logic. The student model initialization block (lines 200-305) could be extracted into a helper like _initialize_student_model(...).

  8. [Readability] training_loop.py — The train() function is ~300 lines with deeply nested control flow for logging, validation, checkpoint saving, and time-based signals. Consider extracting checkpoint-save logic and logging logic into separate functions.

  9. [Readability] stitched_model_factory.py:434-435 — Blank lines between the closing of the function and the backward-compatible aliases (gqa_factory_fn = bypass_factory_fn, moe_factory_fn = bypass_factory_fn). These aliases have no callers in this PR and no documentation. If they're for backward compat with existing configs, add a comment. If they're unused, remove them.

  10. [Tests] The GPU tests are thorough for the happy path but don't test checkpoint resume (loading from a previous run). The find_last_ckpt_for_resume + load_local_state path is complex and untested. At minimum, a test that runs bypass, then runs it again with find_last_ckpt_for_resume=True to verify resume works would increase confidence.

  11. [Tests] No unit test for _set_keys_to_learn which has significant branching logic (subblock types, hybrid model block_configs filtering, regex fallback). This function is critical for correctness.

  12. [Correctness] puzzletron_nas_plugin.py — The new auto-download logic in convert_puzzletron_model (lines 152-165) runs snapshot_download only on rank 0 inside if dist.is_master(), but then all ranks call dist.barrier(). If the download takes a long time, the barrier timeout (set in main.py as timedelta(10) = 10 days) should be fine, but the input_model_path variable is only updated on rank 0 — other ranks never use it since only rank 0 does the conversion. This is correct but subtle; a comment would help.

  13. [Correctness] bypass_utils.py:50set_experiment_dir assigns a Path object to cfg.bypass.experiment_dir, but OmegaConf/DictConfig doesn't natively support Path objects. This works because OmegaConf stores it as-is in struct mode off, but it may cause serialization issues (e.g., json_dump in save_bypass_checkpoint). Consider converting to str.

Suggestions:

  • The _copy_auto_map_code_files addition in checkpoint_utils_hf.py is a good fix for trust_remote_code models. Consider adding a brief unit test or at least a comment about which models require this (e.g., NemotronH).
  • The format_stitched_losses NaN filtering is a nice quality-of-life improvement for hybrid models. The import math inside the function body should be moved to the module top-level.
  • The dataset.py chat_template fallback is correct and handles base models gracefully.
  • The child_init.py fix (return item instead of return item_overrides when None) is a real bug fix — good catch.

Overall Assessment: This is a well-structured, substantial feature addition. The core architecture (stitched model factory, per-block KD, pipeline integration) is sound. However, the hardcoded trust_remote_code=True security issue and the duplicated normalized_mse_loss need to be addressed before merge. The torch.load calls should also use weights_only=True per project convention.

- Fix realize_best_or_latest: add find_best_run_dir() and update
  realize_bypass_checkpoints() to honor the config field (was always
  using the latest checkpoint regardless of the setting)

- Improve experiment ID generation: replace hard-coded parsing logic
  with a config-driven spec table (_OVERRIDE_COMPONENT_SPECS) that
  handles FFN, MoE, GQA, and Mamba in a unified way; fix None values
  being included in IDs (e.g. bypass_ffn_None_heads_4 → bypass_kv4);
  new format: bypass_ffn256_kv4, bypass_experts4, bypass_mamba, etc.

- Simplify checkpoint resume: replace wasteful state-dict dict-merge
  with load_state_dict(strict=False); add weights_only=True to all
  torch.load() calls

- Refactor stitched_model_factory: extract _initialize_student_model()
  helper to reduce bypass_factory_fn from ~250 to ~100 lines; document
  the block_loss_func keyword-argument contract (input=, target=)

- Add find_best_run_dir to checkpoint_utils; add NemotronH example to
  _copy_auto_map_code_files docstring

- Tests: add GPU test for checkpoint resume (find_last_ckpt_for_resume
  path); add unit tests for _set_keys_to_learn (all branches including
  hybrid Mamba/GQA filtering) and set_experiment_id (11 cases)

- Fix ruff N806 in main.py (N → n); fix PT006 in test_bypass_utils.py;
  update copyright year to 2026 on all new bypass files
@Separius Separius requested a review from a team as a code owner April 2, 2026 13:23
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 8

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/puzzletron/tools/bypassed_training/child_init.py (1)

811-843: ⚠️ Potential issue | 🟠 Major

null overrides still crash for optional nested dataclasses.

The new item_overrides is None branch is bypassed when previous_value is None and _is_dataclass_type(item_type) is true, so an override like ...: null still becomes _get_dataclass_type(item_type)(**item_overrides) and raises at runtime. This is easy to hit for optional sub-configs that default to None.

Suggested fix
```diff
-            if previous_value is None and _is_dataclass_type(item_type):
-                new_value = _get_dataclass_type(item_type)(**item_overrides)
+            if item_overrides is None:
+                new_value = previous_value
+            elif previous_value is None and _is_dataclass_type(item_type):
+                assert isinstance(item_overrides, dict)
+                new_value = _get_dataclass_type(item_type)(**item_overrides)
             else:
                 new_value = override(previous_value, item_overrides)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/tools/bypassed_training/child_init.py` around lines
811 - 843, The dataclass_override loop special-case instantiates a nested
dataclass even when the provided override is None, causing a crash; modify the
block in dataclass_override that handles "previous_value is None and
_is_dataclass_type(item_type)" to first check if item_overrides is None and in
that case set new_value = None (or call override(previous_value,
item_overrides)), otherwise instantiate with
_get_dataclass_type(item_type)(**item_overrides); keep the subsequent
check_type(new_value, item_type) and existing symbols (override,
dataclass_override, _is_dataclass_type, _get_dataclass_type, check_type) so
optional nested dataclass overrides that are null no longer raise.
🧹 Nitpick comments (3)
modelopt/torch/puzzletron/tools/bypassed_training/child_init.py (1)

96-115: Reject overlapping outputs from multiple pruning mixins.

layer_out_state_dict.update(_layer_out) silently makes the final checkpoint depend on mixin order if two mixins emit the same state-dict key. Failing fast here is safer than letting one mixin overwrite the other.

Suggested guard
```diff
-            layer_out_state_dict.update(_layer_out)
+            overlapping_keys = layer_out_state_dict.keys() & _layer_out.keys()
+            if overlapping_keys:
+                raise ValueError(
+                    f"Pruning mixins produced overlapping keys for layer {layer_idx}: "
+                    f"{sorted(overlapping_keys)}"
+                )
+            layer_out_state_dict.update(_layer_out)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/tools/bypassed_training/child_init.py` around lines
96 - 115, The loop over pruning mixins currently does
layer_out_state_dict.update(_layer_out) which allows later mixins to silently
overwrite keys from earlier ones; change this to detect overlapping keys and
fail fast: for each _mixin when you get _layer_out from prune_single_layer,
compute intersection = set(layer_out_state_dict.keys()) & set(_layer_out.keys())
and if intersection is non-empty raise a ValueError (or AssertionError) listing
the conflicting keys and the mixin identity (use _mixin or its type) instead of
updating; only call layer_out_state_dict.update(_layer_out) when intersection is
empty to ensure deterministic, non-overlapping outputs from prune_single_layer
across mixins.
modelopt/torch/puzzletron/tools/kd_model.py (1)

38-39: Add a zero/near-zero target regression test for this denominator change.

This adjustment mainly changes behavior when target has tiny norm, but tests/unit/torch/puzzletron/test_bypass_losses.py currently only covers random tensors. A focused zero-target case would keep this stabilization behavior from regressing unnoticed.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/tools/kd_model.py` around lines 38 - 39, Add a unit
test in tests/unit/torch/puzzletron/test_bypass_losses.py (e.g.,
test_kd_loss_zero_and_near_zero_target) that imports the kd loss implementation
from modelopt.torch.puzzletron.tools.kd_model, constructs both a zero target and
a near-zero target tensor, calls the code path that computes loss using the
expression containing F.mse_loss(input, target, reduction=reduction) /
(F.mse_loss(target, torch.zeros_like(target), reduction=reduction) + epsilon),
and asserts the loss is finite and behaves stably (no division-by-zero, not
NaN/Inf) for both cases; use the same input tensor for both and check that
adding the epsilon in the denominator prevents regressions.
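To make the stabilization concrete, here is a pure-Python rendering of the eps-guarded normalized MSE; the real kd_model.py uses torch.nn.functional.mse_loss, and this scalar version only illustrates the zero-target behavior:

```python
def normalized_mse(inputs, targets, eps=1e-8):
    # MSE of the prediction, normalized by the target's mean squared energy;
    # eps keeps the ratio finite when the target is (near) zero.
    mse = sum((a - b) ** 2 for a, b in zip(inputs, targets)) / len(inputs)
    target_energy = sum(t * t for t in targets) / len(targets)
    return mse / (target_energy + eps)

zero_target_loss = normalized_mse([0.1, -0.2], [0.0, 0.0])
# finite and well-defined: no ZeroDivisionError, no NaN/Inf
```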
modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py (1)

377-381: Consider more descriptive error handling for invalid block_loss_func.

If cfg.model_factory.block_loss_func is not one of the three supported values, a KeyError is raised with just the invalid key name. A more descriptive error would help users identify the misconfiguration quickly.

Suggested improvement
```diff
+    _BLOCK_LOSS_FUNCS = {
+        "normalized_mse_loss": normalized_mse_loss,
+        "vectorwise_normalized_mse_loss": vectorwise_normalized_mse_loss,
+        "batched_normalized_mse_loss": batched_normalized_mse_loss,
+    }
+    loss_func_name = cfg.model_factory.block_loss_func
+    if loss_func_name not in _BLOCK_LOSS_FUNCS:
+        raise ValueError(
+            f"Unknown block_loss_func '{loss_func_name}'. "
+            f"Supported: {list(_BLOCK_LOSS_FUNCS.keys())}"
+        )
-    block_loss_func = {
-        "normalized_mse_loss": normalized_mse_loss,
-        "vectorwise_normalized_mse_loss": vectorwise_normalized_mse_loss,
-        "batched_normalized_mse_loss": batched_normalized_mse_loss,
-    }[cfg.model_factory.block_loss_func]
+    block_loss_func = _BLOCK_LOSS_FUNCS[loss_func_name]
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py`
around lines 377 - 381, The current lookup for block_loss_func in
stitched_model_factory.py uses a direct dict index which raises an opaque
KeyError when cfg.model_factory.block_loss_func is invalid; replace it with a
guarded lookup: retrieve via dict.get or check membership first and raise a
ValueError with a clear message that includes the invalid value and the allowed
options (e.g., "normalized_mse_loss", "vectorwise_normalized_mse_loss",
"batched_normalized_mse_loss"); update the code around the block_loss_func
assignment (the dict and its use) so callers get a descriptive error instead of
a raw KeyError.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py`:
- Around line 113-125: The checkpoint load/save is missing GradScaler state so
when use_grad_scaling=True resumed runs lose scaler state; update the save and
load paths around the StitchedModuleDescriptor handling to persist
grad_scaler.state_dict() (e.g., save to
stitched/{stitched_module_name}.grad_scaler.pth) when grad_scaler is not None
and on load (in the blocks that currently load optimizer state and in the
similar 165-171 block) call grad_scaler.load_state_dict(...) after constructing
or retrieving the module’s grad_scaler, using map_location=device, and guard
with the use_grad_scaling flag so scaler state is restored only when applicable.

In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py`:
- Around line 597-600: Log currently prints an empty block name because
submodule_name is never reassigned; update the mprint call in
stitched_model_factory.py (the one that prints Block {submodule_name}: ...) to
use a valid identifier such as student_stitched_module_name (or module_name) so
the block name is meaningful; locate the mprint invocation and replace
submodule_name with student_stitched_module_name (ensuring
student_stitched_module_name is in scope) and keep the rest of the
message/parameter counting intact.
- Around line 417-428: The code assumes owned_block_indexes is non-empty before
calling min()/max(), which will raise ValueError if a rank owns no blocks; in
the block around min_owned_index/max_owned_index in stitched_model_factory.py,
first check if not owned_block_indexes and handle it defensively (e.g., set
prev_rank and next_rank to None or raise a clear, explanatory error) instead of
calling min()/max(); update the logic that computes prev_rank and next_rank
using model_blocks_process_ownership and all_block_indices to only run when
owned_block_indexes is non-empty so misconfiguration yields a clear message or
safe defaults.

In `@modelopt/torch/puzzletron/bypass_distillation/training_loop.py`:
- Around line 837-838: The code reads source_datasets_to_discard from cfg.bypass
root but the new config nests it under bypass.data; update the dataloader calls
that set source_datasets_to_discard (and any similar occurrences) to read from
cfg.bypass.data.get("source_datasets_to_discard", tuple()) instead of
cfg.bypass.get(...), leaving bos_rate as cfg.bypass.data.bos_rate; search for
the occurrences that set source_datasets_to_discard (the two places mentioned
around the calls that also use bos_rate) and replace them to use cfg.bypass.data
so the discard list becomes configurable.
- Around line 252-253: The parameter skip_first_batches is never applied: after
creating the batch iterator from the ConstantLengthDataset/dataloader you must
advance that iterator by skip_first_batches before entering the training loop
(e.g., consume the iterator with next(...) in a short loop or use
itertools.islice to drop the first N items); update the code paths where
skip_first_batches is accepted (the occurrences around skip_first_batches in
training_loop.py and the second occurrence at lines ~329-330) to consume the
iterator accordingly so resumed runs do not replay from batch 0.
- Around line 349-350: The loop exit condition uses a 1-based counter and
currently uses >=, causing it to stop one step too early; update the check in
the training loop that references cfg.bypass.step_num and
cfg.bypass.training.max_steps so it breaks only once step_num has passed the
budget (use > instead of >=) so the final scheduled step runs.
- Around line 103-107: The AutoConfig.from_pretrained call inside
run_bypassed_training bypasses the earlier trust_remote_code decision; update
the AutoConfig.from_pretrained invocation in training_loop.py (the block that
imports HFAutoConfig and sets teacher_hf_cfg/teacher_intermediate_size) to pass
the same trust_remote_code flag you query earlier (the
descriptor/trust_remote_code value used by run_bypassed_training) so
remote-code-required models load consistently and safely. Locate the
AutoConfig.from_pretrained usage and add the trust_remote_code argument (sourced
from the existing hydra_cfg/descriptor/trust_remote_code variable) when calling
from_pretrained, ensuring the call mirrors the trusted-remote-code decision made
elsewhere.

In `@tests/gpu/torch/puzzletron/test_puzzletron.py`:
- Around line 209-215: The test incorrectly computes per-rank FFN counts from
hidden layer count (total_layers = max(2, size)); instead compute the actual
number of prunable FFN blocks (e.g., scan the model's layer names or modules to
count FFN/prunable blocks rather than using hidden-layer count) into
total_ffn_blocks, then compute layers_this_rank = total_ffn_blocks // size + (1
if rank < total_ffn_blocks % size else 0) and assert len(layer_names) ==
layers_this_rank (allowing 0 for ranks that only own Mamba blocks); update the
variables total_layers/layers_this_rank and reference layer_names when making
this change.
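The per-rank split the test prompt asks for reduces to remainder arithmetic; this sketch (`blocks_for_rank` is a hypothetical name) shows the counts stay balanced, sum to the total, and allow zero-block ranks:

```python
def blocks_for_rank(total_ffn_blocks: int, size: int, rank: int) -> int:
    """Even split of prunable FFN blocks: earlier ranks absorb the remainder."""
    return total_ffn_blocks // size + (1 if rank < total_ffn_blocks % size else 0)

# 7 prunable FFN blocks over 4 ranks -> [2, 2, 2, 1]
print([blocks_for_rank(7, 4, r) for r in range(4)])
# 2 FFN blocks over 4 ranks (Mamba-heavy model) -> [1, 1, 0, 0]
print([blocks_for_rank(2, 4, r) for r in range(4)])
```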

---

Outside diff comments:
In `@modelopt/torch/puzzletron/tools/bypassed_training/child_init.py`:
- Around line 811-843: The dataclass_override loop special-case instantiates a
nested dataclass even when the provided override is None, causing a crash;
modify the block in dataclass_override that handles "previous_value is None and
_is_dataclass_type(item_type)" to first check if item_overrides is None and in
that case set new_value = None (or call override(previous_value,
item_overrides)), otherwise instantiate with
_get_dataclass_type(item_type)(**item_overrides); keep the subsequent
check_type(new_value, item_type) and existing symbols (override,
dataclass_override, _is_dataclass_type, _get_dataclass_type, check_type) so
optional nested dataclass overrides that are null no longer raise.
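A minimal, torch-free sketch of the null-aware branch the prompt describes (helper names mirror the prompt, but the real `_is_dataclass_type`/`check_type` are replaced by stdlib stand-ins, and `LrSchedule` is a hypothetical nested config):

```python
import dataclasses

def _is_dataclass_type(tp) -> bool:
    # stand-in for the real helper; Optional unwrapping omitted
    return isinstance(tp, type) and dataclasses.is_dataclass(tp)

def override_nested(previous_value, item_type, item_overrides):
    """A None override keeps the field unset instead of crashing on item_type(**None)."""
    if previous_value is None and _is_dataclass_type(item_type):
        if item_overrides is None:
            return None  # explicitly-null override: nothing to instantiate
        return item_type(**item_overrides)
    return previous_value if item_overrides is None else item_overrides

@dataclasses.dataclass
class LrSchedule:  # hypothetical nested config dataclass
    lr: float = 0.1

print(override_nested(None, LrSchedule, None))          # → None, no crash
print(override_nested(None, LrSchedule, {"lr": 0.5}))   # → LrSchedule(lr=0.5)
```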

---

Nitpick comments:
In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py`:
- Around line 377-381: The current lookup for block_loss_func in
stitched_model_factory.py uses a direct dict index which raises an opaque
KeyError when cfg.model_factory.block_loss_func is invalid; replace it with a
guarded lookup: retrieve via dict.get or check membership first and raise a
ValueError with a clear message that includes the invalid value and the allowed
options (e.g., "normalized_mse_loss", "vectorwise_normalized_mse_loss",
"batched_normalized_mse_loss"); update the code around the block_loss_func
assignment (the dict and its use) so callers get a descriptive error instead of
a raw KeyError.
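The guarded lookup from this nitpick can look like the following (the registry keys come from the review; `resolve_block_loss_func` and the lambda bodies are illustrative placeholders, not the real losses):

```python
_BLOCK_LOSS_FUNCS = {
    "normalized_mse_loss": lambda: "nmse",                 # placeholder callables
    "vectorwise_normalized_mse_loss": lambda: "vnmse",
    "batched_normalized_mse_loss": lambda: "bnmse",
}

def resolve_block_loss_func(name: str):
    """Membership check first, so a bad config value fails with a descriptive ValueError."""
    if name not in _BLOCK_LOSS_FUNCS:
        raise ValueError(
            f"Unknown block_loss_func {name!r}; expected one of {sorted(_BLOCK_LOSS_FUNCS)}"
        )
    return _BLOCK_LOSS_FUNCS[name]

print(resolve_block_loss_func("normalized_mse_loss")())  # → nmse
```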

In `@modelopt/torch/puzzletron/tools/bypassed_training/child_init.py`:
- Around line 96-115: The loop over pruning mixins currently does
layer_out_state_dict.update(_layer_out) which allows later mixins to silently
overwrite keys from earlier ones; change this to detect overlapping keys and
fail fast: for each _mixin when you get _layer_out from prune_single_layer,
compute intersection = set(layer_out_state_dict.keys()) & set(_layer_out.keys())
and if intersection is non-empty raise a ValueError (or AssertionError) listing
the conflicting keys and the mixin identity (use _mixin or its type) instead of
updating; only call layer_out_state_dict.update(_layer_out) when intersection is
empty to ensure deterministic, non-overlapping outputs from prune_single_layer
across mixins.
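The fail-fast merge this nitpick describes is only a few lines (`merge_disjoint` is a hypothetical name; the mixin identity is passed as a string here):

```python
def merge_disjoint(layer_out_state_dict: dict, _layer_out: dict, mixin_name: str) -> None:
    """Update only when key sets are disjoint; overlaps raise instead of silently winning."""
    conflicts = set(layer_out_state_dict) & set(_layer_out)
    if conflicts:
        raise ValueError(f"Mixin {mixin_name} overwrites keys: {sorted(conflicts)}")
    layer_out_state_dict.update(_layer_out)

out: dict = {}
merge_disjoint(out, {"mlp.up_proj.weight": "A"}, "FFNPruning")
merge_disjoint(out, {"attn.q_proj.weight": "B"}, "AttnPruning")
print(sorted(out))  # → ['attn.q_proj.weight', 'mlp.up_proj.weight']
```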

In `@modelopt/torch/puzzletron/tools/kd_model.py`:
- Around line 38-39: Add a unit test in
tests/unit/torch/puzzletron/test_bypass_losses.py (e.g.,
test_kd_loss_zero_and_near_zero_target) that imports the kd loss implementation
from modelopt.torch.puzzletron.tools.kd_model, constructs both a zero target and
a near-zero target tensor, calls the code path that computes loss using the
expression containing F.mse_loss(input, target, reduction=reduction) /
(F.mse_loss(target, torch.zeros_like(target), reduction=reduction) + epsilon),
and asserts the loss is finite and behaves stably (no division-by-zero, not
NaN/Inf) for both cases; use the same input tensor for both and check that
adding the epsilon in the denominator prevents regressions.
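The stability property the requested test should pin down can be sketched without torch (`normalized_mse` below re-implements the `F.mse_loss` ratio from the prompt on plain floats; the epsilon value is an assumption, not taken from kd_model.py):

```python
import math

def normalized_mse(inputs, targets, epsilon: float = 1e-8) -> float:
    """mse(inputs, targets) / (mse(targets, 0) + epsilon), on plain Python floats."""
    n = len(targets)
    num = sum((i - t) ** 2 for i, t in zip(inputs, targets)) / n
    den = sum(t * t for t in targets) / n + epsilon  # epsilon guards the all-zero target
    return num / den

zero_target = normalized_mse([1.0, -1.0], [0.0, 0.0])
near_zero = normalized_mse([1.0, -1.0], [1e-12, -1e-12])
print(math.isfinite(zero_target) and math.isfinite(near_zero))  # → True
```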

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 6cb35ef5-41ea-4f6d-990a-791e2c99b812

📥 Commits

Reviewing files that changed from the base of the PR and between e018ca0 and 2b99327.

📒 Files selected for processing (70)
  • examples/puzzletron/BYPASS.md
  • examples/puzzletron/README.md
  • examples/puzzletron/configs/bypass/defaults.yaml
  • examples/puzzletron/configs/gptoss-20b_remove_experts_memory/bypass/defaults.yaml
  • examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b.yaml
  • examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml
  • examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml
  • examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml
  • examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml
  • examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml
  • examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/Llama-3_2-3B.yaml
  • examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/bypass/defaults.yaml
  • examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/pruning_defaults.yaml
  • examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_model_defaults.yaml
  • examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/Mistral-Small-24B.yaml
  • examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/bypass/defaults.yaml
  • examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/pruning_defaults.yaml
  • examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_model_defaults.yaml
  • examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/nemotron-nano-12b-v2/bypass/defaults.yaml
  • examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2.yaml
  • examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/pruning_defaults.yaml
  • examples/puzzletron/configs/nemotron-nano-12b-v2/validate_model_defaults.yaml
  • examples/puzzletron/configs/nemotron-nano-12b-v2/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/pruning/defaults.yaml
  • examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/bypass/defaults.yaml
  • examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/attn_pruning.yaml
  • examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/hidden_dim_pruning.yaml
  • examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/pruning_defaults.yaml
  • examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct.yaml
  • examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_model_defaults.yaml
  • examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/qwen3-8b_pruneffn_memory/bypass/defaults.yaml
  • examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/pruning_defaults.yaml
  • examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b.yaml
  • examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_model_defaults.yaml
  • examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/realize_model/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/scoring/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/validate_model_defaults.yaml
  • examples/puzzletron/configs/validate_solutions_defaults.yaml
  • examples/puzzletron/main.py
  • modelopt/torch/puzzletron/bypass_distillation/__init__.py
  • modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py
  • modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py
  • modelopt/torch/puzzletron/bypass_distillation/data_classes.py
  • modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py
  • modelopt/torch/puzzletron/bypass_distillation/training_loop.py
  • modelopt/torch/puzzletron/dataset/prepare_dataset.py
  • modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py
  • modelopt/torch/puzzletron/pruning/pruning_utils.py
  • modelopt/torch/puzzletron/sewing_kit/utils.py
  • modelopt/torch/puzzletron/tools/bypassed_training/child_init.py
  • modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
  • modelopt/torch/puzzletron/tools/kd_model.py
  • modelopt/torch/puzzletron/utils/data/dataloaders.py
  • modelopt/torch/puzzletron/utils/parsing.py
  • modelopt/torch/utils/plugins/transformers_dataset.py
  • tests/_test_utils/torch/puzzletron/utils.py
  • tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py
  • tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py
  • tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml
  • tests/gpu/torch/puzzletron/test_bypass.py
  • tests/gpu/torch/puzzletron/test_puzzletron.py
  • tests/unit/torch/puzzletron/__init__.py
  • tests/unit/torch/puzzletron/test_bypass_losses.py
  • tests/unit/torch/puzzletron/test_bypass_utils.py
✅ Files skipped from review due to trivial changes (51)
  • examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml
  • tests/unit/torch/puzzletron/__init__.py
  • examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/bypass/defaults.yaml
  • examples/puzzletron/configs/nemotron-nano-12b-v2/validate_solutions_defaults.yaml
  • modelopt/torch/puzzletron/dataset/prepare_dataset.py
  • examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/pruning_defaults.yaml
  • examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml
  • examples/puzzletron/configs/gptoss-20b_remove_experts_memory/bypass/defaults.yaml
  • examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/pruning_defaults.yaml
  • examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_model_defaults.yaml
  • modelopt/torch/utils/plugins/transformers_dataset.py
  • examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/nemotron-nano-12b-v2/bypass/defaults.yaml
  • examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/Mistral-Small-24B.yaml
  • examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct.yaml
  • examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml
  • examples/puzzletron/configs/realize_model/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/qwen3-8b_pruneffn_memory/bypass/defaults.yaml
  • tests/_test_utils/torch/puzzletron/utils.py
  • examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_model_defaults.yaml
  • examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/hidden_dim_pruning.yaml
  • examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/bypass/defaults.yaml
  • examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/Llama-3_2-3B.yaml
  • examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_model_defaults.yaml
  • examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b.yaml
  • examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/pruning_defaults.yaml
  • examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_model_defaults.yaml
  • examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b.yaml
  • tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml
  • examples/puzzletron/configs/nemotron-nano-12b-v2/validate_model_defaults.yaml
  • examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/attn_pruning.yaml
  • examples/puzzletron/configs/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml
  • examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/pruning_defaults.yaml
  • examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2.yaml
  • examples/puzzletron/configs/scoring/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/bypass/defaults.yaml
  • examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_solutions_defaults.yaml
  • examples/puzzletron/README.md
  • modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
  • examples/puzzletron/configs/validate_model_defaults.yaml
  • modelopt/torch/puzzletron/bypass_distillation/data_classes.py
  • modelopt/torch/puzzletron/bypass_distillation/__init__.py
  • examples/puzzletron/configs/bypass/defaults.yaml
  • examples/puzzletron/BYPASS.md
  • tests/unit/torch/puzzletron/test_bypass_losses.py
  • examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/pruning_defaults.yaml
🚧 Files skipped from review as they are similar to previous changes (6)
  • examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml
  • modelopt/torch/puzzletron/utils/parsing.py
  • modelopt/torch/puzzletron/pruning/pruning_utils.py
  • modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py
  • modelopt/torch/puzzletron/sewing_kit/utils.py
  • modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py

Comment on lines +113 to +125
        if optimizer is not None:
            optimizer_state_path = (
                load_dir / "stitched" / f"{stitched_module_name}.optimizer_state.pth"
            )
            if verbose:
                mprint(
                    f"Loading optimizer state for module {stitched_module_name} from {optimizer_state_path}"
                )
            loaded_optimizer_state = torch.load(
                optimizer_state_path, map_location=device, weights_only=True
            )
            optimizer.load_state_dict(loaded_optimizer_state)
            del loaded_optimizer_state

⚠️ Potential issue | 🟠 Major

Persist GradScaler state as part of the bypass checkpoint.

StitchedModuleDescriptor includes grad_scaler, but the checkpoint only saves/restores model and optimizer state. With use_grad_scaling=True, a resumed run restarts from a fresh scale factor instead of the checkpointed training state.

💡 Suggested fix
         if optimizer is not None:
             optimizer_state_path = (
                 load_dir / "stitched" / f"{stitched_module_name}.optimizer_state.pth"
             )
             if verbose:
                 mprint(
                     f"Loading optimizer state for module {stitched_module_name} from {optimizer_state_path}"
                 )
             loaded_optimizer_state = torch.load(
                 optimizer_state_path, map_location=device, weights_only=True
             )
             optimizer.load_state_dict(loaded_optimizer_state)
             del loaded_optimizer_state
+
+        grad_scaler = stitched_module_descriptor.grad_scaler
+        if grad_scaler is not None:
+            scaler_state_path = load_dir / "stitched" / f"{stitched_module_name}.grad_scaler_state.pth"
+            loaded_scaler_state = torch.load(
+                scaler_state_path, map_location=device, weights_only=True
+            )
+            grad_scaler.load_state_dict(loaded_scaler_state)
+            del loaded_scaler_state
...
         if optimizer is not None:
             optimizer_state_path = save_dir / f"{stitched_module_name}.optimizer_state.pth"
             if verbose:
                 mprint(
                     f"Saving optimizer state for module {stitched_module_name} to {optimizer_state_path}"
                 )
             _save_local_file(optimizer.state_dict(), optimizer_state_path, overwrite=overwrite)
+
+        grad_scaler = stitched_module_descriptor.grad_scaler
+        if grad_scaler is not None:
+            scaler_state_path = save_dir / f"{stitched_module_name}.grad_scaler_state.pth"
+            _save_local_file(grad_scaler.state_dict(), scaler_state_path, overwrite=overwrite)

Also applies to: 165-171

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py`
around lines 113 - 125, The checkpoint load/save is missing GradScaler state so
when use_grad_scaling=True resumed runs lose scaler state; update the save and
load paths around the StitchedModuleDescriptor handling to persist
grad_scaler.state_dict() (e.g., save to
stitched/{stitched_module_name}.grad_scaler.pth) when grad_scaler is not None
and on load (in the blocks that currently load optimizer state and in the
similar 165-171 block) call grad_scaler.load_state_dict(...) after constructing
or retrieving the module’s grad_scaler, using map_location=device, and guard
with the use_grad_scaling flag so scaler state is restored only when applicable.

Comment on lines +417 to +428
    min_owned_index = min(owned_block_indexes)
    max_owned_index = max(owned_block_indexes)
    prev_rank: Optional[int] = (
        None
        if min_owned_index == min(all_block_indices)
        else model_blocks_process_ownership[min_owned_index - 1]
    )
    next_rank: Optional[int] = (
        None
        if max_owned_index == max(all_block_indices)
        else model_blocks_process_ownership[max_owned_index + 1]
    )

⚠️ Potential issue | 🟡 Minor

Potential ValueError if a rank owns no blocks.

min(owned_block_indexes) and max(owned_block_indexes) will raise ValueError if owned_block_indexes is empty. While the current design likely ensures every rank owns at least one block, defensive handling would prevent cryptic errors during misconfiguration.

Suggested defensive check
+    if not owned_block_indexes:
+        raise ValueError(
+            f"Rank {dist.rank()} owns no blocks. Check model_blocks_process_ownership mapping."
+        )
+
     min_owned_index = min(owned_block_indexes)
     max_owned_index = max(owned_block_indexes)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
    min_owned_index = min(owned_block_indexes)
    max_owned_index = max(owned_block_indexes)
    prev_rank: Optional[int] = (
        None
        if min_owned_index == min(all_block_indices)
        else model_blocks_process_ownership[min_owned_index - 1]
    )
    next_rank: Optional[int] = (
        None
        if max_owned_index == max(all_block_indices)
        else model_blocks_process_ownership[max_owned_index + 1]
    )
    if not owned_block_indexes:
        raise ValueError(
            f"Rank {dist.rank()} owns no blocks. Check model_blocks_process_ownership mapping."
        )
    min_owned_index = min(owned_block_indexes)
    max_owned_index = max(owned_block_indexes)
    prev_rank: Optional[int] = (
        None
        if min_owned_index == min(all_block_indices)
        else model_blocks_process_ownership[min_owned_index - 1]
    )
    next_rank: Optional[int] = (
        None
        if max_owned_index == max(all_block_indices)
        else model_blocks_process_ownership[max_owned_index + 1]
    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py`
around lines 417 - 428, The code assumes owned_block_indexes is non-empty before
calling min()/max(), which will raise ValueError if a rank owns no blocks; in
the block around min_owned_index/max_owned_index in stitched_model_factory.py,
first check if not owned_block_indexes and handle it defensively (e.g., set
prev_rank and next_rank to None or raise a clear, explanatory error) instead of
calling min()/max(); update the logic that computes prev_rank and next_rank
using model_blocks_process_ownership and all_block_indices to only run when
owned_block_indexes is non-empty so misconfiguration yields a clear message or
safe defaults.

Comment on lines +597 to +600
            mprint(
                f"Block {submodule_name}: {len(trainable_params)} trainable parameter tensors "
                f"({sum(p.numel() for p in trainable_params.values()):,} params)"
            )

⚠️ Potential issue | 🟡 Minor

Log message will always show empty block name.

submodule_name is initialized to "" at line 449 and never reassigned within the loop. The log message "Block : ..." will always display an empty block name. Consider using student_stitched_module_name (e.g., block_0) or module_name for clarity.

Suggested fix
             mprint(
-                f"Block {submodule_name}: {len(trainable_params)} trainable parameter tensors "
+                f"Block {student_stitched_module_name}: {len(trainable_params)} trainable parameter tensors "
                 f"({sum(p.numel() for p in trainable_params.values()):,} params)"
             )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
            mprint(
                f"Block {submodule_name}: {len(trainable_params)} trainable parameter tensors "
                f"({sum(p.numel() for p in trainable_params.values()):,} params)"
            )
            mprint(
                f"Block {student_stitched_module_name}: {len(trainable_params)} trainable parameter tensors "
                f"({sum(p.numel() for p in trainable_params.values()):,} params)"
            )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py`
around lines 597 - 600, Log currently prints an empty block name because
submodule_name is never reassigned; update the mprint call in
stitched_model_factory.py (the one that prints Block {submodule_name}: ...) to
use a valid identifier such as student_stitched_module_name (or module_name) so
the block name is meaningful; locate the mprint invocation and replace
submodule_name with student_stitched_module_name (ensuring
student_stitched_module_name is in scope) and keep the rest of the
message/parameter counting intact.

Comment on lines +103 to +107
        if do_ffn or do_attn or do_blk:
            from transformers import AutoConfig as HFAutoConfig

            teacher_hf_cfg = HFAutoConfig.from_pretrained(str(hydra_cfg.teacher_dir))
            teacher_intermediate_size = getattr(teacher_hf_cfg, "intermediate_size", None)

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

cat -n modelopt/torch/puzzletron/bypass_distillation/training_loop.py | sed -n '95,120p'

Repository: NVIDIA/Model-Optimizer

Length of output: 1567


🏁 Script executed:

rg "requires_trust_remote_code|trust_remote_code" modelopt/torch/puzzletron/bypass_distillation/training_loop.py -B 2 -A 2

Repository: NVIDIA/Model-Optimizer

Length of output: 614


🏁 Script executed:

rg "ModelDescriptorFactory" modelopt/torch/puzzletron/bypass_distillation/training_loop.py -B 2 -A 2

Repository: NVIDIA/Model-Optimizer

Length of output: 539


🏁 Script executed:

cat -n modelopt/torch/puzzletron/bypass_distillation/training_loop.py | sed -n '1,110p'

Repository: NVIDIA/Model-Optimizer

Length of output: 5786


🏁 Script executed:

rg "def requires_trust_remote_code" modelopt/ -A 3

Repository: NVIDIA/Model-Optimizer

Length of output: 1441


🏁 Script executed:

cat -n modelopt/torch/puzzletron/bypass_distillation/training_loop.py | sed -n '73,250p'

Repository: NVIDIA/Model-Optimizer

Length of output: 9638


🏁 Script executed:

rg "def launch_bypass_distillation" modelopt/torch/puzzletron/bypass_distillation/training_loop.py -A 50

Repository: NVIDIA/Model-Optimizer

Length of output: 2748


🏁 Script executed:

rg "hydra_cfg.descriptor|cfg.descriptor" modelopt/torch/puzzletron/bypass_distillation/training_loop.py -B 2 -A 2

Repository: NVIDIA/Model-Optimizer

Length of output: 344


🏁 Script executed:

cat -n modelopt/torch/puzzletron/bypass_distillation/training_loop.py | sed -n '240,300p'

Repository: NVIDIA/Model-Optimizer

Length of output: 2827


🏁 Script executed:

rg "descriptor\s*=" modelopt/torch/puzzletron/bypass_distillation/training_loop.py | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 558


🏁 Script executed:

cat -n modelopt/torch/puzzletron/bypass_distillation/training_loop.py | sed -n '290,360p'

Repository: NVIDIA/Model-Optimizer

Length of output: 3106


🏁 Script executed:

rg "def run_bypassed_training" modelopt/torch/puzzletron/bypass_distillation/training_loop.py -A 30

Repository: NVIDIA/Model-Optimizer

Length of output: 1259


Thread trust_remote_code through the auto-config probe.

run_bypassed_training() queries the descriptor for trust_remote_code, but this auto-config path at lines 103–107 bypasses that and calls AutoConfig.from_pretrained() with default (unsafe) behavior. Models requiring remote code execution will fail inconsistently depending on which path loads them.

💡 Suggested fix
         if do_ffn or do_attn or do_blk:
             from transformers import AutoConfig as HFAutoConfig
 
-            teacher_hf_cfg = HFAutoConfig.from_pretrained(str(hydra_cfg.teacher_dir))
+            descriptor = ModelDescriptorFactory.get(hydra_cfg.descriptor)
+            trust_remote_code = descriptor.requires_trust_remote_code()
+            teacher_hf_cfg = HFAutoConfig.from_pretrained(
+                str(hydra_cfg.teacher_dir),
+                trust_remote_code=trust_remote_code,
+            )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/bypass_distillation/training_loop.py` around lines
103 - 107, The AutoConfig.from_pretrained call inside run_bypassed_training
bypasses the earlier trust_remote_code decision; update the
AutoConfig.from_pretrained invocation in training_loop.py (the block that
imports HFAutoConfig and sets teacher_hf_cfg/teacher_intermediate_size) to pass
the same trust_remote_code flag you query earlier (the
descriptor/trust_remote_code value used by run_bypassed_training) so
remote-code-required models load consistently and safely. Locate the
AutoConfig.from_pretrained usage and add the trust_remote_code argument (sourced
from the existing hydra_cfg/descriptor/trust_remote_code variable) when calling
from_pretrained, ensuring the call mirrors the trusted-remote-code decision made
elsewhere.

Comment on lines +252 to +253
skip_first_batches: int = 0,
tokenizer: Optional[PreTrainedTokenizerBase] = None,

⚠️ Potential issue | 🟠 Major

skip_first_batches is currently a no-op.

The iterator is created and consumed immediately, but never advanced by skip_first_batches. On resume that replays the training stream from batch 0, because ConstantLengthDataset does not persist iterator position.

💡 Suggested fix
     train_iterator = iter(train_dataloader)
+    if dist.is_master() and skip_first_batches:
+        for _ in range(skip_first_batches):
+            next(train_iterator)
 
     mprint("Waiting for everyone before training starts")

Also applies to: 329-330

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/bypass_distillation/training_loop.py` around lines
252 - 253, The parameter skip_first_batches is never applied: after creating the
batch iterator from the ConstantLengthDataset/dataloader you must advance that
iterator by skip_first_batches before entering the training loop (e.g., consume
the iterator with next(...) in a short loop or use itertools.islice to drop the
first N items); update the code paths where skip_first_batches is accepted (the
occurrences around skip_first_batches in training_loop.py and the second
occurrence at lines ~329-330) to consume the iterator accordingly so resumed
runs do not replay from batch 0.

Comment on lines +349 to +350
        if cfg.bypass.step_num >= cfg.bypass.training.max_steps:
            if (

⚠️ Potential issue | 🔴 Critical

Stop after the last scheduled step, not before it.

With the current 1-based step_num, max_steps=1 exits before the first optimizer step and max_steps=2 only executes one step. This should break once step_num has moved past the budget, not when it is equal to it.

💡 Suggested fix
-        if cfg.bypass.step_num >= cfg.bypass.training.max_steps:
+        if cfg.bypass.step_num > cfg.bypass.training.max_steps:
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-        if cfg.bypass.step_num >= cfg.bypass.training.max_steps:
+        if cfg.bypass.step_num > cfg.bypass.training.max_steps:
             if (
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/bypass_distillation/training_loop.py` around lines
349 - 350, The loop exit condition uses a 1-based counter and currently uses >=,
causing it to stop one step too early; update the check in the training loop
that references cfg.bypass.step_num and cfg.bypass.training.max_steps so it
breaks only once step_num has passed the budget (use > instead of >=) so the
final scheduled step runs.
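
The off-by-one is easy to verify with a toy loop; this standalone sketch (not the PR's code) mimics a 1-based step counter under both conditions:

```python
def steps_executed(max_steps, strict_greater):
    # step_num is 1-based, mirroring cfg.bypass.step_num in the review.
    step_num, executed = 1, 0
    while True:
        # The buggy check uses >=; the fix uses > so the final step runs.
        done = step_num > max_steps if strict_greater else step_num >= max_steps
        if done:
            break
        executed += 1
        step_num += 1
    return executed

print(steps_executed(2, strict_greater=False))  # buggy '>=': only 1 step runs
print(steps_executed(2, strict_greater=True))   # fixed '>': both steps run
```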

Comment on lines +837 to +838
source_datasets_to_discard=cfg.bypass.get("source_datasets_to_discard", tuple()),
bos_rate=cfg.bypass.data.bos_rate,

⚠️ Potential issue | 🟠 Major

Read source_datasets_to_discard from bypass.data.

The new bypass config nests this field under bypass.data, but both dataloader calls read from the bypass root. As written, the discard list is effectively impossible to configure.

💡 Suggested fix
-            source_datasets_to_discard=cfg.bypass.get("source_datasets_to_discard", tuple()),
+            source_datasets_to_discard=cfg.bypass.data.get(
+                "source_datasets_to_discard", tuple()
+            ),
...
-                source_datasets_to_discard=cfg.bypass.get("source_datasets_to_discard", tuple()),
+                source_datasets_to_discard=cfg.bypass.data.get(
+                    "source_datasets_to_discard", tuple()
+                ),

Also applies to: 858-859

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/bypass_distillation/training_loop.py` around lines
837 - 838, The code reads source_datasets_to_discard from cfg.bypass root but
the new config nests it under bypass.data; update the dataloader calls that set
source_datasets_to_discard (and any similar occurrences) to read from
cfg.bypass.data.get("source_datasets_to_discard", tuple()) instead of
cfg.bypass.get(...), leaving bos_rate as cfg.bypass.data.bos_rate; search for
the occurrences that set source_datasets_to_discard (the two places mentioned
around the calls that also use bos_rate) and replace them to use cfg.bypass.data
so the discard list becomes configurable.

Comment on lines +209 to 215
# The test model has num_hidden_layers = max(2, size), so every rank owns at least
# one layer. Compute the actual expected count for *this* rank.
total_layers = max(2, size)
layers_this_rank = total_layers // size + (1 if rank < total_layers % size else 0)
assert len(layer_names) == layers_this_rank, (
f"Expected {layers_this_rank} FFN layers on rank {rank}/{size}, got {len(layer_names)}"
)

⚠️ Potential issue | 🟠 Major

The per-rank FFN count is still wrong for hybrid models.

total_layers = max(2, size) counts hidden layers, not prunable FFN blocks. This file already documents nvidia/NVIDIA-Nemotron-Nano-12B-v2 as having only one FFN layer, so a rank that owns only Mamba blocks can legitimately have len(layer_names) == 0.

💡 Suggested fix
-        total_layers = max(2, size)
-        layers_this_rank = total_layers // size + (1 if rank < total_layers % size else 0)
-        assert len(layer_names) == layers_this_rank, (
-            f"Expected {layers_this_rank} FFN layers on rank {rank}/{size}, got {len(layer_names)}"
-        )
+        total_layers = max(2, size)
+        if len(expected) == total_layers:
+            layers_this_rank = total_layers // size + (1 if rank < total_layers % size else 0)
+            assert len(layer_names) == layers_this_rank, (
+                f"Expected {layers_this_rank} FFN layers on rank {rank}/{size}, got {len(layer_names)}"
+            )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu/torch/puzzletron/test_puzzletron.py` around lines 209 - 215, The
test incorrectly computes per-rank FFN counts from hidden layer count
(total_layers = max(2, size)); instead compute the actual number of prunable FFN
blocks (e.g., scan the model's layer names or modules to count FFN/prunable
blocks rather than using hidden-layer count) into total_ffn_blocks, then compute
layers_this_rank = total_ffn_blocks // size + (1 if rank < total_ffn_blocks %
size else 0) and assert len(layer_names) == layers_this_rank (allowing 0 for
ranks that only own Mamba blocks); update the variables
total_layers/layers_this_rank and reference layer_names when making this change.
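
For reference, the even-split arithmetic the reviewer proposes can be checked in isolation; total_ffn_blocks below is a stand-in for whatever count the test actually derives from the model:

```python
def layers_for_rank(total_ffn_blocks, rank, world_size):
    # Even split with the remainder going to the lowest ranks, matching
    # `total // size + (1 if rank < total % size else 0)` from the test.
    return total_ffn_blocks // world_size + (
        1 if rank < total_ffn_blocks % world_size else 0
    )

# 5 FFN blocks across 2 ranks: ranks own 3 and 2 blocks.
print([layers_for_rank(5, r, 2) for r in range(2)])
# A hybrid model with a single FFN block: rank 1 legitimately owns 0.
print([layers_for_rank(1, r, 2) for r in range(2)])
```

Note that the split always sums to the total, so an assertion allowing zero layers on some ranks is safe for hybrid (Mamba-heavy) models.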

- Extract common setup preamble (dist.setup, register_hydra_resolvers,
  hydra config load, _total_steps) into _setup() helper in main.py to
  eliminate duplication between run_full_puzzletron and run_mip_only
- Rename uppercase N → n in main.py and puzzletron_nas_plugin.py
- Remove unused gqa_factory_fn and moe_factory_fn aliases from
  stitched_model_factory.py
- Improve BYPASS.md: clarify when to run bypass (KV head reduction,
  no_op blocks, extreme FFN/MoE compression); fix coupled BLD cost
  description (N×M runs vs N+M, not harder to optimise)
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/puzzletron/main.py (1)

102-135: ⚠️ Potential issue | 🟠 Major

Ensure distributed cleanup runs on failure paths.

If convert/search/sweep/MIP raises, dist.cleanup() is skipped. In multi-GPU flows this can leave process groups hanging. Wrap execution in try/finally in both runners.

Proposed fix
 def run_full_puzzletron(hydra_config_path: str):
@@
     hydra_cfg, hydra_config_dir, hydra_config_name, n = _setup(hydra_config_path)
 
     mprint(f"Puzzletron Progress 1/{n}: starting puzzletron pipeline")
-
-    # Convert model (convert from HF to DeciLM, score pruning activations,
-    # prune the model and save pruned checkpoints)
-    input_model = PuzzletronModel()
-    converted_model = mtn.convert(
-        input_model,
-        mode=[
-            (
-                "puzzletron",
-                {
-                    "puzzle_dir": str(hydra_cfg.puzzle_dir),
-                    "input_model_path": hydra_cfg.input_hf_model_path,
-                    "hydra_config_dir": hydra_config_dir,
-                    "hydra_config_name": hydra_config_name,
-                    "dataset_path": str(hydra_cfg.dataset_path),
-                },
-            )
-        ],
-    )
-
-    # Run NAS search (build replacement library and compute stats,
-    # compute one block scores, run MIP and realize models)
-    mtn.search(
-        converted_model,
-        constraints={},  # this is not used as the search space is defined in the hydra config
-        dummy_input=None,  # Not used
-        config={},  # this is not used as the search space is defined in the hydra config
-    )
-
-    dist.cleanup()
+    try:
+        # Convert model (convert from HF to DeciLM, score pruning activations,
+        # prune the model and save pruned checkpoints)
+        input_model = PuzzletronModel()
+        converted_model = mtn.convert(
+            input_model,
+            mode=[
+                (
+                    "puzzletron",
+                    {
+                        "puzzle_dir": str(hydra_cfg.puzzle_dir),
+                        "input_model_path": hydra_cfg.input_hf_model_path,
+                        "hydra_config_dir": hydra_config_dir,
+                        "hydra_config_name": hydra_config_name,
+                        "dataset_path": str(hydra_cfg.dataset_path),
+                    },
+                )
+            ],
+        )
+
+        # Run NAS search (build replacement library and compute stats,
+        # compute one block scores, run MIP and realize models)
+        mtn.search(
+            converted_model,
+            constraints={},  # this is not used as the search space is defined in the hydra config
+            dummy_input=None,  # Not used
+            config={},  # this is not used as the search space is defined in the hydra config
+        )
+    finally:
+        dist.cleanup()
     mprint(f"Puzzletron Progress {n}/{n}: puzzletron pipeline completed (multi-gpu)")
@@
 def run_mip_only(hydra_config_path: str):
@@
-    # Check if sweep mode is enabled
-    if hasattr(hydra_cfg.mip, "sweep") and hydra_cfg.mip.sweep.get("enabled", False):
-        mprint(
-            f"Puzzletron Progress {mip_step}/{n}: running MIP sweep for multiple compression rates (multi-gpu)"
-        )
-        sweep.run_mip_sweep(hydra_cfg)
-    else:
-        # mip_and_realize_models (distributed processing)
-        # TODO: How to make it part of mnt.search() api, similarly to run_full_puzzletron() API
-        mprint(f"Puzzletron Progress {mip_step}/{n}: running MIP and realizing models (multi-gpu)")
-        mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg)
-
-    dist.cleanup()
+    try:
+        # Check if sweep mode is enabled
+        if hasattr(hydra_cfg.mip, "sweep") and hydra_cfg.mip.sweep.get("enabled", False):
+            mprint(
+                f"Puzzletron Progress {mip_step}/{n}: running MIP sweep for multiple compression rates (multi-gpu)"
+            )
+            sweep.run_mip_sweep(hydra_cfg)
+        else:
+            # mip_and_realize_models (distributed processing)
+            # TODO: How to make it part of mnt.search() api, similarly to run_full_puzzletron() API
+            mprint(f"Puzzletron Progress {mip_step}/{n}: running MIP and realizing models (multi-gpu)")
+            mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg)
+    finally:
+        dist.cleanup()
     mprint(f"Puzzletron Progress {n}/{n}: puzzletron pipeline completed (multi-gpu)")

Also applies to: 147-163

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/puzzletron/main.py` around lines 102 - 135, The current flow calls
dist.cleanup() after running mtn.convert and mtn.search but if
mtn.convert/mtn.search (or any subsequent step) raises an exception the cleanup
is skipped; wrap the multi-GPU pipeline (from _setup through
mtn.search/mtn.sweep/mtn.MIP calls around lines that create PuzzletronModel,
call mtn.convert and mtn.search) in a try/finally block so dist.cleanup() always
runs, and apply the same try/finally pattern to the other runner block
referenced around lines 147-163; ensure the try encompasses all work that
requires the distributed group and the finally calls dist.cleanup()
unconditionally.
🧹 Nitpick comments (1)
modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py (1)

377-381: Consider clearer error handling for unknown block_loss_func.

If cfg.model_factory.block_loss_func is not one of the three expected values, a KeyError is raised with a cryptic message. A more informative error would help users diagnose configuration issues.

♻️ Suggested improvement
-    block_loss_func = {
+    _BLOCK_LOSS_FUNCS = {
         "normalized_mse_loss": normalized_mse_loss,
         "vectorwise_normalized_mse_loss": vectorwise_normalized_mse_loss,
         "batched_normalized_mse_loss": batched_normalized_mse_loss,
-    }[cfg.model_factory.block_loss_func]
+    }
+    loss_func_name = cfg.model_factory.block_loss_func
+    if loss_func_name not in _BLOCK_LOSS_FUNCS:
+        raise ValueError(
+            f"Unknown block_loss_func '{loss_func_name}'. "
+            f"Expected one of: {list(_BLOCK_LOSS_FUNCS.keys())}"
+        )
+    block_loss_func = _BLOCK_LOSS_FUNCS[loss_func_name]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py`
around lines 377 - 381, The current lookup for block_loss_func using a dict
keyed by cfg.model_factory.block_loss_func can raise a cryptic KeyError; update
the code around block_loss_func (in stitched_model_factory.py) to explicitly
validate cfg.model_factory.block_loss_func against the allowed names
("normalized_mse_loss", "vectorwise_normalized_mse_loss",
"batched_normalized_mse_loss") and raise a clear ValueError that includes the
invalid value and the list of valid options; reference the existing functions
normalized_mse_loss, vectorwise_normalized_mse_loss, and
batched_normalized_mse_loss when constructing the mapping and error message so
users can quickly see the supported choices.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/puzzletron/main.py`:
- Around line 99-101: The docstring parameter name is incorrect — replace the
documented `config_path` with the actual function parameter `hydra_config_path`
and update its description to match (e.g., "Path to the YAML configuration
file") so the `hydra_config_path` argument in the function signature and the
docstring are consistent; locate the docstring in examples/puzzletron/main.py
near the function that accepts `hydra_config_path` and make this single-name
correction.

---

Outside diff comments:
In `@examples/puzzletron/main.py`:
- Around line 102-135: The current flow calls dist.cleanup() after running
mtn.convert and mtn.search but if mtn.convert/mtn.search (or any subsequent
step) raises an exception the cleanup is skipped; wrap the multi-GPU pipeline
(from _setup through mtn.search/mtn.sweep/mtn.MIP calls around lines that create
PuzzletronModel, call mtn.convert and mtn.search) in a try/finally block so
dist.cleanup() always runs, and apply the same try/finally pattern to the other
runner block referenced around lines 147-163; ensure the try encompasses all
work that requires the distributed group and the finally calls dist.cleanup()
unconditionally.

---

Nitpick comments:
In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py`:
- Around line 377-381: The current lookup for block_loss_func using a dict keyed
by cfg.model_factory.block_loss_func can raise a cryptic KeyError; update the
code around block_loss_func (in stitched_model_factory.py) to explicitly
validate cfg.model_factory.block_loss_func against the allowed names
("normalized_mse_loss", "vectorwise_normalized_mse_loss",
"batched_normalized_mse_loss") and raise a clear ValueError that includes the
invalid value and the list of valid options; reference the existing functions
normalized_mse_loss, vectorwise_normalized_mse_loss, and
batched_normalized_mse_loss when constructing the mapping and error message so
users can quickly see the supported choices.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 99826f98-f6fb-41d2-a78c-5c40bec6c4c9

📥 Commits

Reviewing files that changed from the base of the PR and between 351b44e and 346408b.

📒 Files selected for processing (3)
  • examples/puzzletron/main.py
  • modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py
  • modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py

Comment on lines +99 to +101
Args:
config_path: Path to the YAML configuration file
"""

⚠️ Potential issue | 🟡 Minor

Fix docstring argument name mismatch.

Line 100 documents config_path, but the function argument is hydra_config_path. Please align the docstring to avoid confusion.

Proposed fix
 def run_full_puzzletron(hydra_config_path: str):
@@
     Args:
-        config_path: Path to the YAML configuration file
+        hydra_config_path: Path to the YAML configuration file
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
     Args:
-        config_path: Path to the YAML configuration file
+        hydra_config_path: Path to the YAML configuration file
     """
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/puzzletron/main.py` around lines 99 - 101, The docstring parameter
name is incorrect — replace the documented `config_path` with the actual
function parameter `hydra_config_path` and update its description to match
(e.g., "Path to the YAML configuration file") so the `hydra_config_path`
argument in the function signature and the docstring are consistent; locate the
docstring in examples/puzzletron/main.py near the function that accepts
`hydra_config_path` and make this single-name correction.

Extract four self-contained blocks from the 436-line train() function
into named helpers, reducing it to ~290 lines:

- _save_final_checkpoint(): saves the final checkpoint when max_steps
  is reached and cleans up old iter-* checkpoints
- _log_training_stats(): master-only block that processes loss history
  in log_interval chunks, updates best-loss tracking, prints tables
  via format_stitched_losses, and optionally logs to W&B
- _run_validation(): runs the distributed validation pipeline,
  broadcasts val_loss from the last rank, and saves the best
  checkpoint if validation loss improved
- _save_interval_checkpoint(): handles step-interval and time-based
  checkpoint saving, including kill_after_first_save semantics

No behavioral changes — pure mechanical extraction.
@Separius
Author

Separius commented Apr 2, 2026

@cjluo-nv addressed all the points (thanks again for the great review)
